Unverified commit c0f867af authored by Patrick von Platen, committed by GitHub

Fix temb attention (#3607)

* Fix temb attention

* Apply suggestions from code review

* make style

* Add tests and fix docker

* Apply suggestions from code review
parent c6ae8837
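
In short, the diff below threads the `temb` conditioning argument through the LoRA, xFormers, LoRA-xFormers, and sliced attention processors and applies `attn.spatial_norm` when it is set, adds `pytorch-lightning` and `xformers` to the test Docker image, introduces a `disable_full_determinism` helper in the testing utilities, and adds a Stable Diffusion pipeline test that exercises each processor variant.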
@@ -38,6 +38,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
    scipy \
    tensorboard \
    transformers \
-   omegaconf
+   omegaconf \
+   pytorch-lightning \
+   xformers

CMD ["/bin/bash"]
@@ -540,9 +540,14 @@ class LoRAAttnProcessor(nn.Module):
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
        residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
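
The new `if attn.spatial_norm is not None` branch is the heart of the fix: attention blocks that carry a spatial norm need the conditioning tensor `temb`, and processors that never accepted it would fail on such blocks. As a rough sketch of what a spatially conditioned norm like `attn.spatial_norm(hidden_states, temb)` computes (modeled on diffusers' SpatialNorm; treat the exact layer names and hyperparameters as assumptions), it is a GroupNorm whose per-pixel scale and shift are predicted from `temb`:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialNormSketch(nn.Module):
    # Minimal sketch of a spatially conditioned GroupNorm: the conditioning
    # feature map (passed in as `temb` by the processors above) is resized to
    # the feature map and turned into a per-pixel scale and shift via 1x1 convs.
    def __init__(self, f_channels: int, zq_channels: int):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=f_channels, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        normed = self.norm_layer(f)
        return normed * self.conv_y(zq) + self.conv_b(zq)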
@@ -905,9 +910,13 @@ class XFormersAttnProcessor:
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
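
Why does every processor need the extra keyword at all? In diffusers, `Attention.forward` hands any extra keyword arguments straight to whichever processor is installed, so when a block that uses spatial norm passes `temb`, a processor whose `__call__` lacks that parameter raises a TypeError. A minimal sketch of that dispatch (heavily simplified; the real Attention class also sets up projections, norms, and more):

import torch.nn as nn

class AttentionSketch(nn.Module):
    # Simplified stand-in for diffusers' Attention, showing only how extra
    # keyword arguments reach the installed processor.
    def __init__(self, processor):
        super().__init__()
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # Keywords such as `temb` are forwarded verbatim, so every installed
        # processor has to accept them in its __call__ signature.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )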
@@ -1081,9 +1090,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
        residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
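
For reference, the `to_*_lora` attributes set up in both LoRA processors are small low-rank adapters. A sketch of what such a `LoRALinearLayer` boils down to (the real diffusers layer also supports an optional alpha scaling, so take this as an approximation):

import torch
import torch.nn as nn

class LoRALinearLayerSketch(nn.Module):
    # Low-rank adapter: project down to `rank` dimensions and back up.
    # The up projection starts at zero so the adapter is a no-op before training.
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))

Inside the processors, each adapter's output is added to the frozen projection and scaled, along the lines of `value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)`.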
@@ -1334,8 +1348,12 @@ class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

-    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape
......
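
For context on the sliced processor above: attention slicing trades speed for peak memory by computing the attention weights over chunks of the batch-times-heads dimension, with the `slice_size` stored in `__init__` controlling the chunk size. A generic sketch of the idea, not the exact diffusers implementation:

import torch

def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch * heads, seq_len, head_dim)
    # Compute softmax(QK^T / sqrt(d)) V in chunks along the first dimension so
    # only `slice_size` attention matrices are materialized at a time.
    scale = query.shape[-1] ** -0.5
    out = torch.empty_like(query)
    for start in range(0, query.shape[0], slice_size):
        end = start + slice_size
        scores = torch.bmm(query[start:end], key[start:end].transpose(1, 2)) * scale
        probs = scores.softmax(dim=-1)
        out[start:end] = torch.bmm(probs, value[start:end])
    return out

q = k = v = torch.randn(8, 64, 40)
print(sliced_attention(q, k, v, slice_size=2).shape)  # torch.Size([8, 64, 40])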
@@ -577,3 +577,9 @@ def enable_full_determinism():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
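
`disable_full_determinism` simply undoes the global flags that `enable_full_determinism` sets; the new pipeline test needs it, presumably because the xformers kernels it exercises have no deterministic implementation under `torch.use_deterministic_algorithms(True)`. If the pair were needed often, a hypothetical convenience wrapper (not part of this diff) could look like:

from contextlib import contextmanager

from diffusers.utils.testing_utils import disable_full_determinism, enable_full_determinism

@contextmanager
def non_deterministic():
    # Hypothetical helper: temporarily relax the global determinism flags,
    # then restore them when the block exits.
    disable_full_determinism()
    try:
        yield
    finally:
        enable_full_determinism()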
@@ -37,16 +37,18 @@ from diffusers import (
    UNet2DConditionModel,
    logging,
)
-from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import (
    CaptureLogger,
+    disable_full_determinism,
    enable_full_determinism,
    require_torch_2,
    require_torch_gpu,
    run_test_in_subprocess,
)

+from ...models.test_lora_layers import create_unet_lora_layers
from ...models.test_models_unet_2d_condition import create_lora_layers
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -366,6 +368,56 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
+    def test_stable_diffusion_attn_processors(self):
+        disable_full_determinism()
+        device = "cuda"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # run normal sd pipe
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run xformers attention
+        sd_pipe.enable_xformers_memory_efficient_attention()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run attention slicing
+        sd_pipe.enable_attention_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run vae attention slicing
+        sd_pipe.enable_vae_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora xformers attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {
+            k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
+            for k, v in attn_processors.items()
+        }
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        enable_full_determinism()
+
    def test_stable_diffusion_no_safety_checker(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
......