Unverified commit 0cc3a7a1 authored by Patrick von Platen, committed by GitHub


Make sure we also change the config when setting `encoder_hid_dim_type=="text_proj"` and allow xformers (#3615)

* fix if

* make style

* make style

* add tests for xformers

* make style

* update
parent 9d3ff079
......@@ -215,11 +215,8 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
if isinstance(seed_tiles_mode, str):
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
if any(
mode not in (modes := [mode.value for mode in self.SeedTilesMode])
for row in seed_tiles_mode
for mode in row
):
modes = [mode.value for mode in self.SeedTilesMode]
if any(mode not in modes for row in seed_tiles_mode for mode in row):
raise ValueError(f"Seed tiles mode must be one of {modes}")
if seed_reroll_regions is None:
seed_reroll_regions = []
......
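For context, here is a minimal standalone sketch of the validation refactor in the hunk above: the list of valid mode strings is built once up front, so the same `modes` list is used both for the membership check and for the error message. The enum members below are illustrative placeholders, not necessarily the community pipeline's actual `SeedTilesMode` values.

```python
from enum import Enum


class SeedTilesMode(Enum):
    # illustrative values only; the real pipeline defines its own members
    FULL = "full"
    EXCLUSIVE = "exclusive"


def validate_seed_tiles_mode(seed_tiles_mode, modes_enum=SeedTilesMode):
    # build the list of valid mode strings once, then reuse it both for the
    # membership test and for the error message
    modes = [mode.value for mode in modes_enum]
    if any(mode not in modes for row in seed_tiles_mode for mode in row):
        raise ValueError(f"Seed tiles mode must be one of {modes}")


validate_seed_tiles_mode([["full", "exclusive"], ["full", "full"]])  # passes silently
# validate_seed_tiles_mode([["full", "bogus"]])  # would raise ValueError
```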
frog.png (binary file, 108 KB)
......@@ -166,22 +166,28 @@ class Attention(nn.Module):
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
(
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
),
)
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
"Memory efficient attention with `xformers` is currently not supported when"
" `self.added_kv_proj_dim` is defined."
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
)
elif not is_xformers_available():
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
......@@ -233,6 +239,15 @@ class Attention(nn.Module):
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_added_kv_processor:
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
# throw warning
logger.info(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
......@@ -889,6 +904,71 @@ class LoRAAttnAddedKVProcessor(nn.Module):
return hidden_states
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class XFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
......@@ -1428,6 +1508,7 @@ AttentionProcessor = Union[
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
......
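As a usage-level illustration of what the `attention_processor.py` changes unlock, the sketch below enables xformers on a UNet whose attention layers use added KV projections (IF/UnCLIP style). With this commit the call swaps those layers to `XFormersAttnAddedKVProcessor` instead of raising `NotImplementedError`. The checkpoint id and dtype are only illustrative, and a CUDA machine with `xformers` installed is assumed.

```python
import torch
from diffusers import UNet2DConditionModel

# illustrative checkpoint; any UNet with added-KV attention (IF/UnCLIP style) applies
unet = UNet2DConditionModel.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", subfolder="unet", torch_dtype=torch.float16
)

# previously this raised NotImplementedError because added_kv_proj_dim was set;
# it now routes those attention layers to XFormersAttnAddedKVProcessor
unet.enable_xformers_memory_efficient_attention()

# inspect which processor each attention layer ended up with
print({name: type(proc).__name__ for name, proc in unet.attn_processors.items()})
```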
......@@ -261,6 +261,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
......
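A small sketch of the behavior the `register_to_config` change guarantees (all dimensions below are arbitrary toy values): when only `encoder_hid_dim` is passed, the derived `encoder_hid_dim_type="text_proj"` is now written into the model config as well, so it survives a `save_pretrained`/`from_pretrained` round trip instead of living only in a local variable.

```python
from diffusers import UNet2DConditionModel

# toy-sized UNet; the only argument relevant here is encoder_hid_dim
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=16,  # encoder_hid_dim_type is deliberately not passed
)

# with this commit the derived value is part of the config
assert unet.config.encoder_hid_dim_type == "text_proj"
```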
......@@ -364,6 +364,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
......
......@@ -28,6 +28,7 @@ from diffusers import (
IFSuperResolutionPipeline,
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
......@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_dummy_components()
......@@ -81,6 +80,13 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
expected_max_diff=1e-2,
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
@slow
@require_torch_gpu
......
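For readers unfamiliar with the shared tester mixin, the helper these new tests call does roughly the following (a simplified sketch under stated assumptions, not the actual `PipelineTesterMixin` code): run the pipeline once with xformers enabled and once without, then assert the outputs stay within `expected_max_diff`.

```python
import numpy as np


def xformers_forward_pass_check(pipe, inputs, expected_max_diff=1e-3):
    # simplified sketch: compare pipeline outputs with and without xformers
    pipe.enable_xformers_memory_efficient_attention()
    output_xformers = pipe(**inputs).images

    pipe.disable_xformers_memory_efficient_attention()
    output_regular = pipe(**inputs).images

    max_diff = np.abs(np.asarray(output_xformers) - np.asarray(output_regular)).max()
    assert max_diff < expected_max_diff, f"xformers output deviates by {max_diff}"
```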
......@@ -20,6 +20,7 @@ import torch
from diffusers import IFImg2ImgPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device
from ..pipeline_params import (
......@@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_dummy_components()
......@@ -63,6 +62,13 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
......
......@@ -20,6 +20,7 @@ import torch
from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
......@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_superresolution_dummy_components()
......@@ -59,6 +58,13 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
return inputs
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
......
......@@ -20,6 +20,7 @@ import torch
from diffusers import IFInpaintingPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device
from ..pipeline_params import (
......@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_dummy_components()
......@@ -62,6 +61,13 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
return inputs
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
......
......@@ -20,6 +20,7 @@ import torch
from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device
from ..pipeline_params import (
......@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_superresolution_dummy_components()
......@@ -64,6 +63,13 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
return inputs
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
......
......@@ -20,6 +20,7 @@ import torch
from diffusers import IFSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
......@@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
test_xformers_attention = False
def get_dummy_components(self):
return self._get_superresolution_dummy_components()
......@@ -57,6 +56,13 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
return inputs
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
......