Unverified Commit 1908c476 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Deprecate `upcast_vae` in SDXL based pipelines (#12619)

* update

* update

* Revert "update"

This reverts commit 73906381ab76da96eb8f9b841177cd4f49861eb1.

* Revert "update"

This reverts commit 21a03f93ef0fbfa5f7a7d97708f75149b1d1b3b0.

* update

* update

* update

* update

* update
parent 759ea587
...@@ -22,11 +22,6 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz ...@@ -22,11 +22,6 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
...@@ -590,22 +585,12 @@ class StableDiffusionXLInstructPix2PixPipeline( ...@@ -590,22 +585,12 @@ class StableDiffusionXLInstructPix2PixPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self): def upcast_vae(self):
dtype = self.vae.dtype deprecate(
self.vae.to(dtype=torch.float32) "upcast_vae",
use_torch_2_0_or_xformers = isinstance( "1.0.0",
self.vae.decoder.mid_block.attentions[0].processor, "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
(
AttnProcessor2_0,
XFormersAttnProcessor,
FusedAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need self.vae.to(dtype=torch.float32)
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(dtype)
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
......
...@@ -34,15 +34,12 @@ from ...loaders import ( ...@@ -34,15 +34,12 @@ from ...loaders import (
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate,
is_torch_xla_available, is_torch_xla_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
...@@ -779,21 +776,12 @@ class StableDiffusionXLAdapterPipeline( ...@@ -779,21 +776,12 @@ class StableDiffusionXLAdapterPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self): def upcast_vae(self):
dtype = self.vae.dtype deprecate(
self.vae.to(dtype=torch.float32) "upcast_vae",
use_torch_2_0_or_xformers = isinstance( "1.0.0",
self.vae.decoder.mid_block.attentions[0].processor, "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
) )
# if xformers or torch_2_0 is used attention block does not need self.vae.to(dtype=torch.float32)
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(dtype)
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
# Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
def _default_height_width(self, height, width, image): def _default_height_width(self, height, width, image):
......
...@@ -19,16 +19,12 @@ from transformers import ( ...@@ -19,16 +19,12 @@ from transformers import (
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
BaseOutput, BaseOutput,
deprecate,
is_invisible_watermark_available, is_invisible_watermark_available,
logging, logging,
scale_lora_layers, scale_lora_layers,
...@@ -457,22 +453,12 @@ class TextToVideoZeroSDXLPipeline( ...@@ -457,22 +453,12 @@ class TextToVideoZeroSDXLPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self): def upcast_vae(self):
dtype = self.vae.dtype deprecate(
self.vae.to(dtype=torch.float32) "upcast_vae",
use_torch_2_0_or_xformers = isinstance( "1.0.0",
self.vae.decoder.mid_block.attentions[0].processor, "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
(
AttnProcessor2_0,
XFormersAttnProcessor,
FusedAttnProcessor2_0,
),
) )
# if xformers or torch_2_0 is used attention block does not need self.vae.to(dtype=torch.float32)
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(dtype)
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids( def _get_add_time_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment