"...en/git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "92e1164e2e4e96b8fa9ebb5ea8344f2188ea7b86"
Unverified commit c1079f08 authored by Daniel Socek, committed by GitHub

Fix textual inversion SDXL and add support for 2nd text encoder (#9010)



* Fix textual inversion SDXL and add support for 2nd text encoder
Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Fix style/quality of text inv for sdxl
Signed-off-by: Daniel Socek <daniel.socek@intel.com>

---------
Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 65e30907
@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
   --output_dir="./textual_inversion_cat_sdxl"
 ```
-For now, only training of the first text encoder is supported.
\ No newline at end of file
+Training of both text encoders is supported.
+
+### Inference Example
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionXLPipeline`.
+Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+model_id = "./textual_inversion_cat_sdxl"
+pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A <cat-toy> backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack.png")
+
+image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-prompt_2.png")
+```
@@ -135,7 +135,7 @@ def log_validation(
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=text_encoder_2,
+        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
@@ -678,36 +678,54 @@ def main():
             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
             " `placeholder_token` that is not already in the tokenizer."
         )
+    num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
+    if num_added_tokens != args.num_vectors:
+        raise ValueError(
+            f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+            " `placeholder_token` that is not already in the tokenizer."
+        )
 
     # Convert the initializer_token, placeholder_token to ids
     token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
+    token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
+
     # Check if initializer_token is a single token or a sequence of tokens
-    if len(token_ids) > 1:
+    if len(token_ids) > 1 or len(token_ids_2) > 1:
         raise ValueError("The initializer token must be a single token.")
 
     initializer_token_id = token_ids[0]
     placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
+    initializer_token_id_2 = token_ids_2[0]
+    placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder_1.resize_token_embeddings(len(tokenizer_1))
+    text_encoder_2.resize_token_embeddings(len(tokenizer_2))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder_1.get_input_embeddings().weight.data
+    token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
     with torch.no_grad():
         for token_id in placeholder_token_ids:
             token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+        for token_id in placeholder_token_ids_2:
+            token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
 
     # Freeze vae and unet
     vae.requires_grad_(False)
     unet.requires_grad_(False)
-    text_encoder_2.requires_grad_(False)
+
     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder_1.text_model.encoder.requires_grad_(False)
     text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
     text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder_2.text_model.encoder.requires_grad_(False)
+    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
+        text_encoder_2.gradient_checkpointing_enable()
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -746,7 +764,11 @@ def main():
     optimizer_class = torch.optim.AdamW
 
     optimizer = optimizer_class(
-        text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
+        # only optimize the embeddings
+        [
+            text_encoder_1.text_model.embeddings.token_embedding.weight,
+            text_encoder_2.text_model.embeddings.token_embedding.weight,
+        ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -786,9 +808,10 @@ def main():
     )
 
     text_encoder_1.train()
+    text_encoder_2.train()
     # Prepare everything with our `accelerator`.
-    text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder_1, optimizer, train_dataloader, lr_scheduler
+    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
     )
 
     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():
 
     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
 
     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder_1.train()
+        text_encoder_2.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder_1):
+            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * vae.config.scaling_factor
@@ -892,9 +917,7 @@ def main():
                     .hidden_states[-2]
                     .to(dtype=weight_dtype)
                 )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
                 encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                 original_size = [
                     (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -938,11 +961,16 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+                index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
 
                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
                         index_no_updates
                     ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+                        index_no_updates_2
+                    ] = orig_embeds_params_2[index_no_updates_2]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -960,6 +988,16 @@ def main():
                         save_path,
                         safe_serialization=True,
                     )
+                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+                    save_path = os.path.join(args.output_dir, weight_name)
+                    save_progress(
+                        text_encoder_2,
+                        placeholder_token_ids_2,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=True,
+                    )
 
                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=accelerator.unwrap_model(text_encoder_1),
-            text_encoder_2=text_encoder_2,
+            text_encoder_2=accelerator.unwrap_model(text_encoder_2),
             vae=vae,
             unet=unet,
             tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
             save_path,
             safe_serialization=True,
         )
+        weight_name = "learned_embeds_2.safetensors"
+        save_path = os.path.join(args.output_dir, weight_name)
+        save_progress(
+            text_encoder_2,
+            placeholder_token_ids_2,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=True,
+        )
 
         if args.push_to_hub:
             save_model_card(
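For a quick sanity check of what the newly added `save_progress` calls for the second text encoder write out, the embedding files can be inspected directly with `safetensors`. This is a small sketch, not part of this commit; it assumes the default output directory and placeholder token, and the shapes shown depend on `--num_vectors` and on the hidden sizes of the two SDXL text encoders (768 and 1280 for the base checkpoint).

```python
from safetensors.torch import load_file

# Each file maps the placeholder token to its learned embedding rows.
embeds_1 = load_file("./textual_inversion_cat_sdxl/learned_embeds.safetensors")
embeds_2 = load_file("./textual_inversion_cat_sdxl/learned_embeds_2.safetensors")

print({token: tuple(weight.shape) for token, weight in embeds_1.items()})
# e.g. {'<cat-toy>': (1, 768)}   -> first text encoder (CLIP ViT-L)
print({token: tuple(weight.shape) for token, weight in embeds_2.items()})
# e.g. {'<cat-toy>': (1, 1280)}  -> second text encoder (OpenCLIP ViT-bigG)
```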