"docs/vscode:/vscode.git/clone" did not exist on "7ebc3ad8d720004df01fdde424dcaefc9f43bd6f"
Unverified commit b10f5275, authored by Pedro Cuenca, committed by GitHub

Helper function to disable custom attention processors (#2791)

* Helper function to disable custom attention processors.

* Restore code deleted by mistake.

* Format

* Fix modeling_text_unet copy.
parent 7bc2fff1
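
The diff below adds a `set_default_attn_processor()` method to the UNet-style models so callers can drop any custom attention processors (e.g. LoRA variants) without importing `AttnProcessor` themselves. A minimal usage sketch, not part of the diff (the checkpoint name is borrowed from the tests further down):

```python
import torch
from diffusers import UNet2DConditionModel

# Load a UNet that may carry custom attention processors, then restore the
# stock implementation with the new one-line helper.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16
)
unet.set_default_attn_processor()  # same effect as unet.set_attn_processor(AttnProcessor())
```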
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
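
The helper is a thin wrapper over the existing `set_attn_processor`, which recursively walks the module tree and installs the processor on every attention layer (the tail of that `fn_recursive_attn_processor` closure is visible in the hunk above). A simplified sketch of the traversal, assuming the single-processor case only; the real implementation also accepts a dict mapping module paths to processors:

```python
import torch.nn as nn

def fn_recursive_attn_processor(name: str, module: nn.Module, processor) -> None:
    # Attention layers expose set_processor; other submodules are just traversed.
    if hasattr(module, "set_processor"):
        module.set_processor(processor)
    for sub_name, child in module.named_children():
        fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
```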
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -442,6 +442,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
@@ -7,7 +7,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -533,6 +533,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
@@ -22,7 +22,7 @@ import torch
 from parameterized import parameterized
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -599,7 +599,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
 
-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()
 
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
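
In the test above, resetting to the default processor lets a LoRA pass at `scale=0.0` be compared against a clean model. A hedged sketch of the invariant the helper establishes, continuing with the test's `model` and using the public `attn_processors` property:

```python
from diffusers.models.attention_processor import AttnProcessor

model.set_default_attn_processor()
# Every registered processor should now be the stock implementation.
assert all(isinstance(proc, AttnProcessor) for proc in model.attn_processors.values())
```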
@@ -35,7 +35,6 @@ from diffusers import (
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
@@ -843,7 +842,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         outputs = pipe(**inputs)
@@ -856,7 +855,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
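
These slow-test changes are behavior-preserving renames: the tests pin the attention implementation so reference outputs stay stable regardless of which faster backend happens to be installed. A standalone sketch of the pattern (the prompt and the two inference steps are illustrative choices, not from the diff):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe.unet.set_default_attn_processor()  # deterministic baseline attention
pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse", num_inference_steps=2).images[0]
```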
@@ -32,7 +32,6 @@ from diffusers import (
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
@@ -410,7 +409,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         outputs = pipe(**inputs)
@@ -423,7 +422,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
@@ -25,7 +25,6 @@ import torch
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
@@ -106,16 +105,16 @@ class ModelTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             new_model = self.model_class.from_pretrained(tmpdirname)
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
             new_model.to(torch_device)
 
         with torch.no_grad():
@@ -135,16 +134,16 @@ class ModelTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, variant="fp16")
             new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
 
             # non-variant cannot be loaded
             with self.assertRaises(OSError) as error_context:
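
The `hasattr` guard keeps `ModelTesterMixin` generic: only model classes that define the helper get their processors reset, while models without attention processors (e.g. a plain VAE) skip it. A condensed sketch of the save/reload determinism check the mixin performs, with a hypothetical helper name and tolerance:

```python
import tempfile
import torch

def check_save_load_roundtrip(model_cls, init_dict, inputs_dict):
    # Hypothetical condensation of ModelTesterMixin's round-trip test.
    model = model_cls(**init_dict)
    if hasattr(model, "set_default_attn_processor"):
        model.set_default_attn_processor()
    model.eval()

    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        new_model = model_cls.from_pretrained(tmpdir)
    if hasattr(new_model, "set_default_attn_processor"):
        new_model.set_default_attn_processor()
    new_model.eval()

    # Outputs before and after the round trip should match closely.
    with torch.no_grad():
        out = model(**inputs_dict).sample
        new_out = new_model(**inputs_dict).sample
    assert torch.allclose(out, new_out, atol=1e-4)
```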