Unverified Commit 4ae54b37 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[attention] Fix attention (#2656)

* [attention] Fix attention

* fix

* correct
parent fa7a5761
......@@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
class_labels=None,
):
......@@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
......
......@@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
# make sure that more than 4 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 4e9
assert mem_bytes > 5e9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
def test_stable_diffusion_fp16_vs_autocast(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment