"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4408047ac557a10244413f5f650c2da7106567a1"
Unverified Commit beb1c017 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[advanced dreambooth lora] add clip_skip arg (#8715)



* add clip_skip

* style

* smol fix

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 06ee4db3
...@@ -573,6 +573,13 @@ def parse_args(input_args=None): ...@@ -573,6 +573,13 @@ def parse_args(input_args=None):
default=1e-4, default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.", help="Initial learning rate (after the potential warmup period) to use.",
) )
parser.add_argument(
"--clip_skip",
type=int,
default=None,
help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
"the output of the pre-final layer will be used for computing the prompt embeddings.",
)
parser.add_argument( parser.add_argument(
"--text_encoder_lr", "--text_encoder_lr",
...@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False): ...@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
prompt_embeds_list = [] prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders): for i, text_encoder in enumerate(text_encoders):
...@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): ...@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
# We are only ALWAYS interested in the pooled output of the final text encoder # We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0] pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds[-1][-2] if clip_skip is None:
prompt_embeds = prompt_embeds[-1][-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
bs_embed, seq_len, _ = prompt_embeds.shape bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds) prompt_embeds_list.append(prompt_embeds)
...@@ -1830,9 +1841,9 @@ def main(args): ...@@ -1830,9 +1841,9 @@ def main(args):
tokenizers = [tokenizer_one, tokenizer_two] tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two] text_encoders = [text_encoder_one, text_encoder_two]
def compute_text_embeddings(prompt, text_encoders, tokenizers): def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
with torch.no_grad(): with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)
prompt_embeds = prompt_embeds.to(accelerator.device) prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds return prompt_embeds, pooled_prompt_embeds
...@@ -1842,7 +1853,7 @@ def main(args): ...@@ -1842,7 +1853,7 @@ def main(args):
# the redundant encoding. # the redundant encoding.
if freeze_text_encoder and not train_dataset.custom_instance_prompts: if freeze_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers args.instance_prompt, text_encoders, tokenizers, args.clip_skip
) )
# Handle class prompt for prior-preservation. # Handle class prompt for prior-preservation.
...@@ -2052,7 +2063,7 @@ def main(args): ...@@ -2052,7 +2063,7 @@ def main(args):
if train_dataset.custom_instance_prompts: if train_dataset.custom_instance_prompts:
if freeze_text_encoder: if freeze_text_encoder:
prompt_embeds, unet_add_text_embeds = compute_text_embeddings( prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers prompts, text_encoders, tokenizers, args.clip_skip
) )
else: else:
...@@ -2147,6 +2158,7 @@ def main(args): ...@@ -2147,6 +2158,7 @@ def main(args):
tokenizers=None, tokenizers=None,
prompt=None, prompt=None,
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[tokens_one, tokens_two],
clip_skip=args.clip_skip,
) )
unet_added_conditions.update( unet_added_conditions.update(
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment