Unverified Commit dff5ff35 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SDXL LoRA] fix batch size lora (#4509)

fix batch size lora
parent b2456717
...@@ -1103,11 +1103,11 @@ def main(args): ...@@ -1103,11 +1103,11 @@ def main(args):
"time_ids": add_time_ids.repeat(elems_to_repeat, 1), "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1), "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
} }
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet( model_pred = unet(
noisy_model_input, noisy_model_input,
timesteps, timesteps,
prompt_embeds, prompt_embeds_input,
added_cond_kwargs=unet_added_conditions, added_cond_kwargs=unet_added_conditions,
).sample ).sample
else: else:
...@@ -1119,9 +1119,9 @@ def main(args): ...@@ -1119,9 +1119,9 @@ def main(args):
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[tokens_one, tokens_two],
) )
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds = prompt_embeds.repeat(elems_to_repeat, 1, 1) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet( model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample ).sample
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment