Unverified Commit d6b86140 authored by Pedro Cuenca, committed by GitHub

Correctly keep vae in `float16` when using PyTorch 2 or xFormers (#4019)



* Update pipeline_stable_diffusion_xl.py

fix a bug

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_upscale.py

* style

---------
Co-authored-by: Hu Ye <xiaohuzc@gmail.com>
parent e4f6c379
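
Why the change is needed: the original check used Python's `in` operator against a list of processor classes, which compares the processor instance to the class objects themselves and therefore never matches, while `isinstance` with a tuple of classes tests the instance's type. A minimal sketch with toy stand-in classes (not the actual diffusers processors):

class AttnProcessor2_0:
    pass


class XFormersAttnProcessor:
    pass


processor = AttnProcessor2_0()

# The old membership test compares the instance against the class objects,
# so it is False even though processor *is* an AttnProcessor2_0.
print(processor in [AttnProcessor2_0, XFormersAttnProcessor])             # False

# isinstance accepts a tuple of classes and checks the instance's type.
print(isinstance(processor, (AttnProcessor2_0, XFormersAttnProcessor)))   # True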
@@ -748,12 +748,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
-            AttnProcessor2_0,
-            XFormersAttnProcessor,
-            LoRAXFormersAttnProcessor,
-            LoRAAttnProcessor2_0,
-        ]
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
         if not use_torch_2_0_or_xformers:
...
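
In each pipeline the flag answers the same question: is a memory-efficient attention backend (PyTorch 2.0 SDPA or xFormers) handling the VAE attention, so the modules around it can safely stay in float16? Diffusers infers this from the processor class on the mid-block attention; a rough standalone probe for the same backends (illustrative only, not the pipeline's own logic) might look like:

import importlib.util

import torch


def memory_efficient_attention_available() -> bool:
    # PyTorch 2.0 ships fused scaled-dot-product attention.
    has_sdpa = hasattr(torch.nn.functional, "scaled_dot_product_attention")
    # Otherwise the optional xformers package provides an equivalent kernel.
    has_xformers = importlib.util.find_spec("xformers") is not None
    return has_sdpa or has_xformers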
@@ -786,15 +786,18 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin):
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
-            AttnProcessor2_0,
-            XFormersAttnProcessor,
-            LoRAXFormersAttnProcessor,
-            LoRAAttnProcessor2_0,
-        ]
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
-        if not use_torch_2_0_or_xformers:
+        if use_torch_2_0_or_xformers:
             self.vae.post_quant_conv.to(latents.dtype)
             self.vae.decoder.conv_in.to(latents.dtype)
             self.vae.decoder.mid_block.to(latents.dtype)
...
@@ -860,15 +860,18 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
-        use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
-            AttnProcessor2_0,
-            XFormersAttnProcessor,
-            LoRAXFormersAttnProcessor,
-            LoRAAttnProcessor2_0,
-        ]
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
-        if not use_torch_2_0_or_xformers:
+        if use_torch_2_0_or_xformers:
             self.vae.post_quant_conv.to(latents.dtype)
             self.vae.decoder.conv_in.to(latents.dtype)
             self.vae.decoder.mid_block.to(latents.dtype)
...
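
A quick way to exercise the fixed path, assuming a CUDA machine, PyTorch 2.x (so the default attention processor is AttnProcessor2_0), and the public SDXL base checkpoint; after this patch the decoder's post_quant_conv, conv_in and mid_block are cast back to float16 for decoding instead of remaining in float32:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# With SDPA or xFormers active, decoding now keeps the attention-adjacent
# VAE modules in float16, lowering peak memory.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")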