Unverified commit c1079f08 authored by Daniel Socek, committed by GitHub

Fix textual inversion SDXL and add support for 2nd text encoder (#9010)



* Fix textual inversion SDXL and add support for 2nd text encoder
Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Fix style/quality of text inv for sdxl
Signed-off-by: Daniel Socek <daniel.socek@intel.com>

---------
Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 65e30907
@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
  --output_dir="./textual_inversion_cat_sdxl"
```

Training of both text encoders is supported.
### Inference Example
Once you have trained a model using the above command, inference can be run with the `StableDiffusionXLPipeline`.
Make sure to include the `placeholder_token` in your prompt.
```python
from diffusers import StableDiffusionXLPipeline
import torch

model_id = "./textual_inversion_cat_sdxl"
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A <cat-toy> backpack"

# By default the prompt is routed to both text encoders
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")

# Alternatively, condition only the second text encoder via `prompt_2`
image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-prompt_2.png")
```
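
The saved embeddings can also be loaded into a fresh SDXL base checkpoint with `load_textual_inversion`, loading each file into its matching text encoder and tokenizer. The snippet below is only a sketch under a few assumptions: the training run saved `learned_embeds.safetensors` (first text encoder) and `learned_embeds_2.safetensors` (second text encoder) in the output directory, the placeholder token was `<cat-toy>`, and `stabilityai/stable-diffusion-xl-base-1.0` is used as the base model. Adjust these to match your run.

```python
from diffusers import StableDiffusionXLPipeline
from safetensors.torch import load_file
import torch

# Load a clean SDXL base pipeline (base model id is an assumption; use your own checkpoint)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# File names and the "<cat-toy>" key are assumptions based on the training defaults above
state_dict_1 = load_file("./textual_inversion_cat_sdxl/learned_embeds.safetensors")
state_dict_2 = load_file("./textual_inversion_cat_sdxl/learned_embeds_2.safetensors")

# Register the learned token with each text encoder / tokenizer pair
pipe.load_textual_inversion(
    state_dict_1["<cat-toy>"], token="<cat-toy>", text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
pipe.load_textual_inversion(
    state_dict_2["<cat-toy>"], token="<cat-toy>", text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
)

image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-base.png")
```

Loading the second file into `text_encoder_2`/`tokenizer_2` is what makes the learned concept usable through `prompt_2` as well.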
...@@ -135,7 +135,7 @@ def log_validation( ...@@ -135,7 +135,7 @@ def log_validation(
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder_1),
        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
        tokenizer=tokenizer_1,
        tokenizer_2=tokenizer_2,
        unet=unet,
@@ -678,36 +678,54 @@ def main():
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer." " `placeholder_token` that is not already in the tokenizer."
) )
num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
if num_added_tokens != args.num_vectors:
raise ValueError(
f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
    # Convert the initializer_token, placeholder_token to ids
    token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
    token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)

    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1 or len(token_ids_2) > 1:
        raise ValueError("The initializer token must be a single token.")

    initializer_token_id = token_ids[0]
    placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
    initializer_token_id_2 = token_ids_2[0]
    placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder_1.resize_token_embeddings(len(tokenizer_1))
    text_encoder_2.resize_token_embeddings(len(tokenizer_2))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder_1.get_input_embeddings().weight.data
    token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids:
            token_embeds[token_id] = token_embeds[initializer_token_id].clone()
        for token_id in placeholder_token_ids_2:
            token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
    # Freeze vae and unet
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder_2.requires_grad_(False)

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder_1.text_model.encoder.requires_grad_(False)
    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder_2.text_model.encoder.requires_grad_(False)
    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        text_encoder_1.gradient_checkpointing_enable()
        text_encoder_2.gradient_checkpointing_enable()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
@@ -746,7 +764,11 @@ def main():
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        # only optimize the embeddings
        [
            text_encoder_1.text_model.embeddings.token_embedding.weight,
            text_encoder_2.text_model.embeddings.token_embedding.weight,
        ],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -786,9 +808,10 @@ def main():
    )

    text_encoder_1.train()
    text_encoder_2.train()

    # Prepare everything with our `accelerator`.
    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
    )
    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():
    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder_1.train()
        text_encoder_2.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor
@@ -892,9 +917,7 @@ def main():
                    .hidden_states[-2]
                    .to(dtype=weight_dtype)
                )
                encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
                encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                original_size = [
                    (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -938,11 +961,16 @@ def main():
                # Let's make sure we don't update any embedding weights besides the newly added token
                index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
                index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False

                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]
                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
                        index_no_updates_2
                    ] = orig_embeds_params_2[index_no_updates_2]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
@@ -960,6 +988,16 @@ def main():
                        save_path,
                        safe_serialization=True,
                    )
                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
                    save_path = os.path.join(args.output_dir, weight_name)
                    save_progress(
                        text_encoder_2,
                        placeholder_token_ids_2,
                        accelerator,
                        args,
                        save_path,
                        safe_serialization=True,
                    )

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder_1),
                text_encoder_2=accelerator.unwrap_model(text_encoder_2),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
            save_path,
            safe_serialization=True,
        )
        weight_name = "learned_embeds_2.safetensors"
        save_path = os.path.join(args.output_dir, weight_name)
        save_progress(
            text_encoder_2,
            placeholder_token_ids_2,
            accelerator,
            args,
            save_path,
            safe_serialization=True,
        )

        if args.push_to_hub:
            save_model_card(
......