Unverified Commit 8ead643b authored by Andreas Jörg's avatar Andreas Jörg Committed by GitHub
Browse files

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast...


[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training
Co-authored-by: default avatarAndreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 124ac3e8
...@@ -1283,8 +1283,8 @@ def main(args): ...@@ -1283,8 +1283,8 @@ def main(args):
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# Get the text embedding for conditioning # Get the text embedding for conditioning
prompt_embeds = batch["prompt_embeds"] prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"] pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
# controlnet(s) inference # controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment