Unverified Commit 4ae54b37 authored by Patrick von Platen, committed by GitHub

[attention] Fix attention (#2656)

* [attention] Fix attention

* fix

* correct
parent fa7a5761
@@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         encoder_hidden_states=None,
+        encoder_attention_mask=None,
         timestep=None,
-        attention_mask=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
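
The fix splits the mask arguments: attention_mask now feeds the self-attention over hidden_states, while the new encoder_attention_mask feeds the cross-attention over encoder_hidden_states. A minimal sketch of that routing, using plain torch.nn.MultiheadAttention as a stand-in for diffusers' attention layers (TinyTransformerBlock and all shapes are illustrative, not taken from the commit):

import torch
import torch.nn as nn

class TinyTransformerBlock(nn.Module):
    """Illustrative stand-in for BasicTransformerBlock's mask routing."""

    def __init__(self, dim, context_dim, heads=8):
        super().__init__()
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.attn2 = nn.MultiheadAttention(
            dim, heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )

    def forward(self, hidden_states, attention_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None):
        # 1. Self-attention: the mask covers the hidden (query == key) tokens.
        hidden_states = hidden_states + self.attn1(
            hidden_states, hidden_states, hidden_states,
            key_padding_mask=attention_mask,
        )[0]
        # 2. Cross-attention: keys/values come from encoder_hidden_states,
        #    so the encoder mask applies here (the routing this commit fixes).
        if encoder_hidden_states is not None:
            hidden_states = hidden_states + self.attn2(
                hidden_states, encoder_hidden_states, encoder_hidden_states,
                key_padding_mask=encoder_attention_mask,
            )[0]
        return hidden_states

x = torch.randn(2, 64, 320)             # 64 "image" tokens
ctx = torch.randn(2, 77, 768)           # 77 "text" tokens
ctx_pad = torch.zeros(2, 77, dtype=torch.bool)
ctx_pad[:, 60:] = True                  # True = ignore these padded positions
out = TinyTransformerBlock(320, 768)(x, encoder_hidden_states=ctx,
                                     encoder_attention_mask=ctx_pad)
print(out.shape)  # torch.Size([2, 64, 320])

The two masks are not interchangeable: the self-attention mask spans the 64 latent tokens while the encoder mask spans the 77 text tokens, which is why reusing attention_mask inside attn2 was a bug.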
@@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
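The TODO above concerns converting a user-facing boolean mask into whatever form the attention processor consumes. One plausible preparation step, sketched under the assumption of an additive float bias; the helper name is hypothetical, not a diffusers API:

import torch

def prepare_encoder_attention_mask(mask, dtype):
    # Hypothetical helper: mask is (batch, key_len), 1/True for real tokens,
    # 0/False for padding. Returns an additive bias broadcastable to
    # (batch, heads, query_len, key_len): 0 where attention is allowed and
    # a large negative number where it is not, so softmax zeroes it out.
    bias = (1.0 - mask.to(dtype)) * torch.finfo(dtype).min
    return bias[:, None, None, :]

mask = torch.tensor([[1, 1, 1, 0, 0]])
bias = prepare_encoder_attention_mask(mask, torch.float32)
print(bias[0, 0, 0])  # 0 for the three real tokens, -3.4028e+38 for padding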
@@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
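For context, the measurement pattern the slow test relies on, assuming a CUDA device and some pipeline call in place of the placeholder:

import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... run the pipeline under test here ...

mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 5e9  # threshold this commit raises from 4e9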