"...text-generation-inference.git" did not exist on "fe710af25f9297afca1ef2d974a0def654775bb7"
Unverified Commit 0ea51627 authored by Sayak Paul, committed by GitHub

[Core] Fix dtype in InstructPix2Pix SDXL while computing `image_latents` (#5013)

* check out dtypes.

* check out dtypes.

* check out dtypes.

* check out dtypes.

* check out dtypes.

* check out dtypes.

* check out dtypes.

* potential fix

* check out dtypes.

* check out dtypes.

* working?
parent 6d6a08f1
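
For context, the change below applies the usual SDXL upcast-encode-downcast handling to the VAE and then keeps the returned latents in the VAE's dtype. A minimal standalone sketch of that pattern (simplified, not the pipeline's actual code; `vae` is assumed to be a diffusers `AutoencoderKL`-style module exposing `dtype` and `config.force_upcast`):

import torch

def encode_image_latents(vae, image):
    # Sketch only: mirrors the upcast/downcast logic in the diff below.
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast

    if needs_upcasting:
        # fp16 VAEs can overflow while encoding, so temporarily run the VAE in fp32
        vae.to(dtype=torch.float32)
        image = image.to(dtype=torch.float32)

    image_latents = vae.encode(image).latent_dist.mode()

    if needs_upcasting:
        # cast the VAE back to fp16 once encoding is done
        vae.to(dtype=torch.float16)

    # keep the latents in the VAE's working dtype so later concatenation and
    # the UNet forward pass see a consistent dtype
    if image_latents.dtype != vae.dtype:
        image_latents = image_latents.to(dtype=vae.dtype)
    return image_latents
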
@@ -495,7 +495,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
             image_latents = image
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            if needs_upcasting:
                 self.upcast_vae()
                 image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
@@ -511,6 +512,10 @@ class StableDiffusionXLInstructPix2PixPipeline(
             else:
                 image_latents = self.vae.encode(image).latent_dist.mode()
 
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
             deprecation_message = (
@@ -533,6 +538,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
             uncond_image_latents = torch.zeros_like(image_latents)
             image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
 
+        if image_latents.dtype != self.vae.dtype:
+            image_latents = image_latents.to(dtype=self.vae.dtype)
+
         return image_latents
 
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
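
A usage sketch of the scenario this fix targets: running the pipeline end to end in float16, where the image latents previously could come back in float32 and mismatch the rest of the model. The checkpoint id and image URL are placeholders, not taken from this commit:

import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

# Placeholder checkpoint id; substitute a real SDXL InstructPix2Pix checkpoint.
pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "your-org/sdxl-instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

# Placeholder input image URL.
image = load_image("https://example.com/input.png").resize((768, 768))

# With this change, `image_latents` are cast back to the VAE's fp16 dtype before use.
edited = pipe(
    prompt="turn the scene into a watercolor painting",
    image=image,
    num_inference_steps=30,
).images[0]
edited.save("edited.png")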