Unverified commit b10f5275, authored by Pedro Cuenca, committed by GitHub

Helper function to disable custom attention processors (#2791)

* Helper function to disable custom attention processors.

* Restore code deleted by mistake.

* Format

* Fix modeling_text_unet copy.
parent 7bc2fff1
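
For context, a minimal sketch of how the new helper can be called from user code: it replaces whatever attention processors a model currently carries (for example LoRA processors) with the stock `AttnProcessor`, equivalent to `model.set_attn_processor(AttnProcessor())`. The checkpoint id below is the one used in the slow tests and is illustrative only:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load a pipeline whose UNet may carry custom attention processors
# (e.g. LoRA processors attached during fine-tuning).
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)

# New helper introduced by this PR: reset every attention module
# back to the default AttnProcessor implementation.
pipe.unet.set_default_attn_processor()
```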
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
...
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -442,6 +442,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
...
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
...
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin from ...models import ModelMixin
from ...models.attention import Attention from ...models.attention import Attention
from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel from ...models.transformer_2d import Transformer2DModel
...@@ -533,6 +533,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -533,6 +533,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size):
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
......
@@ -22,7 +22,7 @@ import torch
 from parameterized import parameterized
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -599,7 +599,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
...
...@@ -35,7 +35,6 @@ from diffusers import ( ...@@ -35,7 +35,6 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
logging, logging,
) )
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
...@@ -843,7 +842,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): ...@@ -843,7 +842,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
torch_dtype=torch.float16, torch_dtype=torch.float16,
) )
pipe.unet.set_attn_processor(AttnProcessor()) pipe.unet.set_default_attn_processor()
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
outputs = pipe(**inputs) outputs = pipe(**inputs)
...@@ -856,7 +855,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase): ...@@ -856,7 +855,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
torch_dtype=torch.float16, torch_dtype=torch.float16,
) )
pipe.unet.set_attn_processor(AttnProcessor()) pipe.unet.set_default_attn_processor()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() torch.cuda.reset_max_memory_allocated()
......
@@ -32,7 +32,6 @@ from diffusers import (
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
@@ -410,7 +409,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         outputs = pipe(**inputs)
@@ -423,7 +422,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
...
@@ -25,7 +25,6 @@ import torch
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
@@ -106,16 +105,16 @@ class ModelTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             new_model = self.model_class.from_pretrained(tmpdirname)
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
             new_model.to(torch_device)
 
         with torch.no_grad():
@@ -135,16 +134,16 @@ class ModelTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, variant="fp16")
             new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
 
             # non-variant cannot be loaded
             with self.assertRaises(OSError) as error_context:
...
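
A quick sanity check of the new helper, sketched under the assumption that the model exposes the existing `attn_processors` accessor (the checkpoint id is illustrative):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Load a standalone UNet (illustrative checkpoint) and reset its processors.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
unet.set_default_attn_processor()

# After the reset, every registered processor should be the default one.
assert all(type(p) is AttnProcessor for p in unet.attn_processors.values())
```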