Unverified Commit 544710ef authored by Bagheera's avatar Bagheera Committed by GitHub
Browse files

diffusers#7426 fix stable diffusion xl inference on MPS when dtypes shift...


diffusers#7426 fix stable diffusion xl inference on MPS when dtypes shift unexpectedly due to pytorch bugs (#7446)

* mps: fix XL pipeline inference at training time due to upstream pytorch bug

* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* apply the safe-guarding logic elsewhere.

---------
Co-authored-by: default avatarbghira <bghira@users.github.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 443aa14e
......@@ -1193,7 +1193,16 @@ class StableDiffusionXLPipeline(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
if callback_on_step_end is not None:
callback_kwargs = {}
......@@ -1228,6 +1237,14 @@ class StableDiffusionXLPipeline(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
......
......@@ -1370,7 +1370,16 @@ class StableDiffusionXLImg2ImgPipeline(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
if callback_on_step_end is not None:
callback_kwargs = {}
......@@ -1405,6 +1414,14 @@ class StableDiffusionXLImg2ImgPipeline(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
......
......@@ -1720,7 +1720,16 @@ class StableDiffusionXLInpaintPipeline(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
if num_channels_unet == 4:
init_latents_proper = image_latents
......@@ -1772,6 +1781,14 @@ class StableDiffusionXLInpaintPipeline(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
......
......@@ -918,7 +918,16 @@ class StableDiffusionXLInstructPix2PixPipeline(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......@@ -937,6 +946,14 @@ class StableDiffusionXLInstructPix2PixPipeline(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment