Commit b6cc0502 authored by William Berman's avatar William Berman Committed by Will Berman
Browse files

fix simple attention processor encoder hidden states ordering

parent 0cbefefa
......@@ -400,7 +400,6 @@ class AttnAddedKVProcessor:
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
......@@ -627,7 +626,6 @@ class SlicedAttnAddedKVProcessor:
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
......
......@@ -77,10 +77,10 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)
text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
return text_encoder_hidden_states, additive_clip_time_embeddings
......@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
test_xformers_attention = False
@property
def text_embedder_hidden_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment