Commit b6cc0502 authored by William Berman, committed by Will Berman

fix simple attention processor encoder hidden states ordering

parent 0cbefefa
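The diff below standardizes encoder_hidden_states on the conventional (batch_size, sequence_length, hidden_dim) layout: the two added-KV attention processors stop transposing the tensor internally, and UnCLIPTextProjModel emits its concatenated context in that layout directly. A minimal shape sketch of the two conventions (tensor names and sizes are illustrative, not from the commit):

    import torch

    batch, seq_len, dim = 2, 77, 32

    # layout the processors expect after this commit: (batch, seq_len, dim)
    seq_first = torch.randn(batch, seq_len, dim)

    # the old channel-first layout, (batch, dim, seq_len), which forced a
    # transpose(1, 2) inside the processors before the projections could run
    channel_first = seq_first.transpose(1, 2)

    assert channel_first.shape == (batch, dim, seq_len)
    assert channel_first.transpose(1, 2).shape == (batch, seq_len, dim)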
@@ -400,7 +400,6 @@ class AttnAddedKVProcessor:
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
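A condensed sketch of the processor body after the change (my simplification, not the full AttnAddedKVProcessor: group norm, head reshaping, and the residual path are omitted). It shows encoder_hidden_states being consumed directly in (batch, seq_len, dim) by the added key/value projections, with no transpose:

    def attn_added_kv_sketch(attn, hidden_states, encoder_hidden_states, attention_mask=None):
        residual = hidden_states
        # spatial input arrives as (batch, channels, height, width); flatten
        # and transpose it into (batch, seq_len, channels)
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # encoder_hidden_states is already (batch, seq_len, dim); the
        # transpose(1, 2) that used to sit here is what the commit removes
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)
        ...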
@@ -627,7 +626,6 @@ class SlicedAttnAddedKVProcessor:
     def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
-        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
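SlicedAttnAddedKVProcessor receives the same one-line fix; what distinguishes it is that attention is computed in slices along the batch-of-heads dimension to bound peak memory. A self-contained sketch of that slicing idea (illustrative, not the exact processor code):

    import torch

    def sliced_attention(query, key, value, slice_size):
        # query/key/value: (batch * heads, seq_len, head_dim)
        scale = query.shape[-1] ** -0.5
        out = torch.empty_like(query)
        for start in range(0, query.shape[0], slice_size):
            end = start + slice_size
            # only a slice_size-sized attention matrix is live at a time
            scores = torch.bmm(query[start:end], key[start:end].transpose(1, 2)) * scale
            out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
        return out

    q = k = v = torch.randn(8, 16, 4)
    assert torch.allclose(sliced_attention(q, k, v, slice_size=2),
                          sliced_attention(q, k, v, slice_size=8), atol=1e-5)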
@@ -77,10 +77,10 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
         # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
         clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
         clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
+        clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)
 
         text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
         text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
-        text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
-        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
+        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
 
         return text_encoder_hidden_states, additive_clip_time_embeddings
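A shape walkthrough of the updated UnCLIPTextProjModel tail: the newly added permute puts the projected image tokens into (batch, n_tokens, dim), so the concatenation with the text encoder states happens along the sequence dimension (dim=1) rather than dim=2. Sizes below are illustrative assumptions:

    import torch

    batch, n_extra_tokens, seq_len, dim = 2, 4, 77, 32

    # projection output reshaped to (batch, dim, n_extra_tokens), then moved
    # into sequence-first layout by the added permute(0, 2, 1)
    clip_extra_context_tokens = torch.randn(batch, dim, n_extra_tokens).permute(0, 2, 1)
    text_encoder_hidden_states = torch.randn(batch, seq_len, dim)

    # with both tensors in (batch, seq, dim), the cat runs along dim=1
    out = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
    assert out.shape == (batch, n_extra_tokens + seq_len, dim)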
@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         "decoder_num_inference_steps",
         "super_res_num_inference_steps",
     ]
+    test_xformers_attention = False
 
     @property
     def text_embedder_hidden_size(self):
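The test change adds a class-level opt-out flag. A hedged sketch of how such a flag typically gates an inherited test; the real PipelineTesterMixin logic in diffusers may differ in details:

    import unittest

    class PipelineTesterSketch(unittest.TestCase):
        test_xformers_attention = True  # subclasses may opt out

        def test_xformers_attention_forward_pass(self):
            if not self.test_xformers_attention:
                self.skipTest("pipeline opts out of the xformers attention test")
            # ... run the pipeline with xformers attention enabled ...

    class UnCLIPImageVariationSketch(PipelineTesterSketch):
        # mirrors the flag added in this commit
        test_xformers_attention = False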