Unverified Commit dff5ff35 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SDXL LoRA] fix batch size lora (#4509)

fix batch size lora
parent b2456717
......@@ -1103,11 +1103,11 @@ def main(args):
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
}
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
prompt_embeds,
prompt_embeds_input,
added_cond_kwargs=unet_added_conditions,
).sample
else:
......@@ -1119,9 +1119,9 @@ def main(args):
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
# Get the target for loss depending on the prediction type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment