Unverified commit 196aef5a authored by Dimitri Barbot, committed by GitHub


Fix pipeline dtype unexpected change when using SDXL reference community pipelines in float16 mode (#10670)

parent 7b100ce5
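
For context: in float16 mode these reference community pipelines temporarily upcast the VAE to float32 while encoding the reference image and previously never cast it back, leaving `pipe.vae` in float32 after the first call. Below is a minimal repro sketch of that behaviour; the `custom_pipeline` registration name, the `ref_image` argument and the image URL are assumptions about the community pipeline's interface, not part of this commit.

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Assumed registration name and call signature for the SDXL reference community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="stable_diffusion_xl_reference",
    torch_dtype=torch.float16,
).to("cuda")

ref_image = load_image("https://example.com/reference.png")  # placeholder reference image

assert pipe.vae.dtype == torch.float16
_ = pipe(prompt="a photo of a cat", ref_image=ref_image, num_inference_steps=20)

# Before this commit the VAE stayed upcast to float32 after the call;
# with the fix it is cast back, so the whole pipeline remains in fp16.
assert pipe.vae.dtype == torch.float16
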
@@ -193,7 +193,8 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
 
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -223,6 +224,11 @@ class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPi
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(
...
@@ -139,7 +139,8 @@ def retrieve_timesteps(
 class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
     def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
         refimage = refimage.to(device=device)
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+        if needs_upcasting:
             self.upcast_vae()
             refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
         if refimage.dtype != self.vae.dtype:
@@ -169,6 +170,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
         # aligning device to prevent device errors when concating it with the latent model input
         ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
         return ref_image_latents
 
     def prepare_ref_image(
...
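
Both hunks apply the same pattern, shown standalone below: the fp16 VAE is temporarily upcast before encoding the reference image (VAEs with `force_upcast` can overflow in half precision) and the newly added lines restore fp16 afterwards, so the caller's dtype is unchanged. The helper name and the simplified upcast (a plain `.to(torch.float32)` instead of `upcast_vae()`) are illustrative only, not taken from the diff.

import torch
from diffusers import AutoencoderKL

def encode_reference_image(vae: AutoencoderKL, image: torch.Tensor, generator=None) -> torch.Tensor:
    # Simplified mirror of the pipelines' prepare_ref_latents dtype handling.
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        # The real pipelines call self.upcast_vae(); a full float32 cast is a simplification.
        vae.to(dtype=torch.float32)
        image = image.to(dtype=torch.float32)
    latents = vae.encode(image).latent_dist.sample(generator)
    latents = vae.config.scaling_factor * latents
    if needs_upcasting:
        # Cast back so the pipeline stays in fp16 after this call (the bug this commit fixes).
        vae.to(dtype=torch.float16)
    return latents
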