Unverified Commit 161449d5 authored by Sayak Paul, committed by GitHub

[SDXL DreamBooth LoRA] multiple fixes (#4262)

* add automatic licensing.

* debugging

* debugging

* more debugging

* more debugging.

* run make fix-copies.

* change to default tracker.
parent 34abee09
@@ -94,6 +94,10 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
 LoRA for the text encoder was enabled: {train_text_encoder}.
 Special VAE used for training: {vae_path}.
+## License
+[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
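For reference, a minimal sketch of how the generated model card picks up the new license section; the repo_folder path and the yaml front matter below are placeholders, not values taken from the training script.

import os

# Placeholder values; the real script builds `yaml` and `model_card` from the
# training arguments before writing the README.
repo_folder = "lora-trained-xl"
yaml = "---\nlicense: other\n---\n"
model_card = (
    "## License\n"
    "[SDXL 0.9 Research License]"
    "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)\n"
)

os.makedirs(repo_folder, exist_ok=True)
with open(os.path.join(repo_folder, "README.md"), "w") as f:
    f.write(yaml + model_card)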
@@ -1077,28 +1081,32 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # Calculate the elements to repeat depending on the use of prior-preservation.
+                elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
+
                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(bsz, 1),
-                        "text_embeds": unet_add_text_embeds.repeat(bsz, 1),
+                        "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
+                        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
                     }
+                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
                         noisy_model_input,
                         timesteps,
-                        prompt_embeds.repeat(bsz, 1, 1),
+                        prompt_embeds,
                         added_cond_kwargs=unet_added_conditions,
                     ).sample
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
                     )
-                    unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
-                    prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
+                    unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
+                    prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
                     model_pred = unet(
                         noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
                     ).sample
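For context on the repeat-count change above, here is a small shape sanity check; it is not taken from the training script, and the batch size and embedding shapes are assumptions. The idea is that with prior preservation the prompt-conditioning tensors already stack one instance row and one class row, while bsz counts the concatenated instance+class latent batch, so the repeat factor has to be bsz // 2.

import torch

# Illustrative shape check for the elems_to_repeat fix; sizes are assumptions,
# not values from the training script.
train_batch_size = 2
bsz = 2 * train_batch_size                    # instance + class latents concatenated
prompt_embeds = torch.randn(2, 77, 2048)      # one row for the instance prompt, one for the class prompt
elems_to_repeat = bsz // 2                    # the with_prior_preservation case

repeated = prompt_embeds.repeat(elems_to_repeat, 1, 1)
assert repeated.shape[0] == bsz               # matches the batch dimension of noisy_model_input

# Repeating by the full bsz (the pre-fix behaviour) would produce twice as many rows:
assert prompt_embeds.repeat(bsz, 1, 1).shape[0] == 2 * bsz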
@@ -1194,12 +1202,8 @@ def main(args):
         pipeline = StableDiffusionXLPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one)
-            if args.train_text_encoder
-            else text_encoder_one,
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two)
-            if args.train_text_encoder
-            else text_encoder_two,
+            text_encoder=accelerator.unwrap_model(text_encoder_one),
+            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
             unet=accelerator.unwrap_model(unet),
             revision=args.revision,
             torch_dtype=weight_dtype,
@@ -836,6 +836,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
             aug_emb = self.add_embedding(text_embs, image_embs)
         elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
             if "text_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
@@ -943,6 +943,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
             aug_emb = self.add_embedding(text_embs, image_embs)
         elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
             if "text_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
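The "# SDXL - style" comments added in both UNet implementations label the text_time branch that consumes SDXL's extra conditioning. As the training hunk earlier in this commit shows, the caller is expected to provide both text_embeds and time_ids in added_cond_kwargs; below is a minimal sketch of that dictionary, with tensor shapes that are assumptions for illustration only.

import torch

# Hypothetical SDXL conditioning inputs; the shapes are illustrative assumptions.
batch = 2
pooled_text_embeds = torch.randn(batch, 1280)  # pooled prompt embeddings
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch, dtype=torch.float32)

added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}

# The "text_time" branch above raises a ValueError if "text_embeds" is missing,
# so the training code populates both keys before calling the UNet.
assert {"text_embeds", "time_ids"} <= added_cond_kwargs.keys()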