Unverified Commit 0cc3a7a1 authored by Patrick von Platen, committed by GitHub
Make sure we also change the config when setting `encoder_hid_dim_type=="text_proj"` and allow xformers (#3615)

* fix if

* make style

* make style

* add tests for xformers

* make style

* update
parent 9d3ff079
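
As a hedged sketch of the user-facing effect of the xformers half of this commit (assumes a CUDA machine with xformers installed and access to the DeepFloyd IF weights; the model id and prompt are illustrative, not part of the diff):

import torch
from diffusers import IFPipeline

# IF's UNet uses added key/value projection attention, so this call used to
# raise NotImplementedError; it now swaps in XFormersAttnAddedKVProcessor.
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
image = pipe("a photo of a frog", num_inference_steps=20).images[0]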
@@ -215,11 +215,8 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-        if any(
-            mode not in (modes := [mode.value for mode in self.SeedTilesMode])
-            for row in seed_tiles_mode
-            for mode in row
-        ):
+        modes = [mode.value for mode in self.SeedTilesMode]
+        if any(mode not in modes for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {modes}")
         if seed_reroll_regions is None:
             seed_reroll_regions = []
...
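
A note on the "fix if" refactor above: the inline walrus expression was evaluated once per element the generator inspected, so hoisting `modes` computes the list a single time (and reads more clearly). A minimal illustration, with a hypothetical `build_modes` standing in for the list construction:

calls = []

def build_modes():
    calls.append(1)
    return ["full", "exact"]

rows = [["full", "exact"], ["exact", "full"]]
any(m not in build_modes() for row in rows for m in row)
print(len(calls))  # 4 -- rebuilt once per element, as the inline walrus form was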
frog.png (new file, 108 KB)

@@ -166,22 +166,28 @@ class Attention(nn.Module):
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
+            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
         )
+        is_added_kv_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                AttnAddedKVProcessor,
+                AttnAddedKVProcessor2_0,
+                SlicedAttnAddedKVProcessor,
+                XFormersAttnAddedKVProcessor,
+                LoRAAttnAddedKVProcessor,
+            ),
+        )
 
         if use_memory_efficient_attention_xformers:
-            if self.added_kv_proj_dim is not None:
-                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
-                # which uses this type of cross attention ONLY because the attention mask of format
-                # [0, ..., -10.000, ..., 0, ...,] is not supported
+            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                 raise NotImplementedError(
-                    "Memory efficient attention with `xformers` is currently not supported when"
-                    " `self.added_kv_proj_dim` is defined."
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                 )
-            elif not is_xformers_available():
+            if not is_xformers_available():
                 raise ModuleNotFoundError(
                     (
                         "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
@@ -233,6 +239,15 @@ class Attention(nn.Module):
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_added_kv_processor:
+                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+                # which uses this type of cross attention ONLY because the attention mask of format
+                # [0, ..., -10.000, ..., 0, ...,] is not supported
+                # throw warning
+                logger.info(
+                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+                )
+                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -889,6 +904,71 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         return hidden_states
 
 
+class XFormersAttnAddedKVProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
@@ -1428,6 +1508,7 @@ AttentionProcessor = Union[
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
...
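
A hedged sketch of the new code path at the `Attention` level (dimensions are illustrative; the toggle itself verifies that xformers is installed and CUDA is available, so this only runs on such a machine):

from diffusers.models.attention_processor import (
    Attention,
    AttnAddedKVProcessor,
    XFormersAttnAddedKVProcessor,
)

# An added-KV attention block of the kind UnCLIP/IF UNets use.
attn = Attention(
    query_dim=64,
    cross_attention_dim=32,
    heads=4,
    dim_head=16,
    added_kv_proj_dim=32,
    norm_num_groups=4,
    processor=AttnAddedKVProcessor(),
)

# Previously this raised NotImplementedError for any added-KV processor; it
# now only raises for the LoRA/custom-diffusion combinations, logs a warning
# about attention masks, and swaps in the xformers processor.
attn.set_use_memory_efficient_attention_xformers(True)
assert isinstance(attn.processor, XFormersAttnAddedKVProcessor)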
@@ -261,6 +261,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
 
         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
...
@@ -364,6 +364,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
 
         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
...
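
A hedged sketch of the config half of the fix, using dummy dimensions in the style of the test suite; the assertion is exactly what the added `register_to_config` call guarantees:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=16,  # `encoder_hid_dim_type` deliberately left unset
)

# The type is inferred as "text_proj" and, with this commit, also written back
# to the config, so it survives save_pretrained / from_pretrained round trips.
assert unet.config.encoder_hid_dim_type == "text_proj"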
@@ -28,6 +28,7 @@ from diffusers import (
     IFSuperResolutionPipeline,
 )
 from diffusers.models.attention_processor import AttnAddedKVProcessor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()
@@ -81,6 +80,13 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
             expected_max_diff=1e-2,
         )
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
 
 @slow
 @require_torch_gpu
...
@@ -20,6 +20,7 @@ import torch
 from diffusers import IFImg2ImgPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device
 
 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()
@@ -63,6 +62,13 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
         # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
...
@@ -20,6 +20,7 @@ import torch
 from diffusers import IFImg2ImgSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()
@@ -59,6 +58,13 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
         return inputs
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
...
@@ -20,6 +20,7 @@ import torch
 from diffusers import IFInpaintingPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device
 
 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()
@@ -62,6 +61,13 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
         return inputs
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
...
@@ -20,6 +20,7 @@ import torch
 from diffusers import IFInpaintingSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device
 
 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()
@@ -64,6 +63,13 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
         return inputs
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
...
@@ -20,6 +20,7 @@ import torch
 from diffusers import IFSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
 
-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()
@@ -57,6 +56,13 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
         return inputs
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
...