Unverified Commit cd91fc06 authored by Patrick von Platen, committed by GitHub

Re-add xformers enable to UNet2DCondition (#1627)



* finish

* fix

* Update tests/models/test_models_unet_2d.py

* style

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent ff65c2d7
@@ -188,6 +188,39 @@ class ModelMixin(torch.nn.Module):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
...
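In practice, the re-added toggles are called directly on a loaded model. A minimal usage sketch, assuming a CUDA machine with `xformers` installed; the checkpoint name is illustrative, and any `UNet2DConditionModel` weights work the same way:

```python
import torch

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; substitute any repo containing a UNet2DConditionModel.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

unet.enable_xformers_memory_efficient_attention()  # lower memory use, possible inference speed-up
# ... run the denoising loop ...
unet.disable_xformers_memory_efficient_attention()  # back to the default attention
```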
@@ -30,6 +30,7 @@ from diffusers.utils import (
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.import_utils import is_xformers_available
 from parameterized import parameterized
 
 from ..test_modeling_common import ModelTesterMixin
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
 
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+        ), "xformers is not enabled"
+
     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
     def test_gradient_checkpointing(self):
         # enable deterministic behavior for gradient checkpointing
...
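To see why the test's deep lookup (`model.mid_block.attentions[0].transformer_blocks[0].attn1`) observes the flag, here is a minimal, self-contained sketch of the recursive walk that the new `ModelMixin.set_use_memory_efficient_attention_xformers` performs; `ToyAttention` and `set_mem_eff` are hypothetical stand-ins for illustration, not diffusers APIs:

```python
import torch


class ToyAttention(torch.nn.Module):
    """Hypothetical stand-in for an attention block exposing the hook."""

    def __init__(self):
        super().__init__()
        self._use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self._use_memory_efficient_attention_xformers = valid


def set_mem_eff(root: torch.nn.Module, valid: bool) -> None:
    # Mirrors the ModelMixin walk: any descendant exposing the hook receives
    # the flag, no matter how deeply it is nested.
    def fn_recursive_set_mem_eff(module: torch.nn.Module):
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(valid)
        for child in module.children():
            fn_recursive_set_mem_eff(child)

    for module in root.children():
        fn_recursive_set_mem_eff(module)


model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Sequential(ToyAttention()),  # attention nested one level down
)
set_mem_eff(model, True)
assert model[1][0]._use_memory_efficient_attention_xformers  # flag propagated
```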