Unverified Commit 161449d5 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[SDXL DreamBooth LoRA] multiple fixes (#4262)

* add automatic licensing.

* debugging

* debugging

* more debugging

* more debugging.

* run make fix-copies.

* change to default tracker.
parent 34abee09
......@@ -94,6 +94,10 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
## License
[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
......@@ -1077,28 +1081,32 @@ def main(args):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Calculate the elements to repeat depending on the use of prior-preservation.
elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if not args.train_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids.repeat(bsz, 1),
"text_embeds": unet_add_text_embeds.repeat(bsz, 1),
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
}
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
prompt_embeds.repeat(bsz, 1, 1),
prompt_embeds,
added_cond_kwargs=unet_added_conditions,
).sample
else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)}
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
......@@ -1194,12 +1202,8 @@ def main(args):
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one)
if args.train_text_encoder
else text_encoder_one,
text_encoder_2=accelerator.unwrap_model(text_encoder_two)
if args.train_text_encoder
else text_encoder_two,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
torch_dtype=weight_dtype,
......
......@@ -836,6 +836,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
......
......@@ -943,6 +943,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment