Unverified Commit d50baf0c authored by Tolga Cangöz's avatar Tolga Cangöz Committed by GitHub
Browse files

Fix image upcasting (#7858)



Fix image's upcasting before `vae.encode()` when using `fp16`
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent c2217142
......@@ -1419,7 +1419,6 @@ class LEditsPPPipelineStableDiffusionXL(
if needs_upcasting:
image = image.float()
self.upcast_vae()
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
x0 = self.vae.encode(image).latent_dist.mode()
x0 = x0.to(dtype)
......
......@@ -525,8 +525,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
image = image.float()
self.upcast_vae()
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment