"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c5d6e0b537bc2097739bd857f8f50ea6499885e4"
Commit b6cc0502 authored by William Berman's avatar William Berman Committed by Will Berman
Browse files

fix simple attention processor encoder hidden states ordering

parent 0cbefefa
...@@ -400,7 +400,6 @@ class AttnAddedKVProcessor: ...@@ -400,7 +400,6 @@ class AttnAddedKVProcessor:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
...@@ -627,7 +626,6 @@ class SlicedAttnAddedKVProcessor: ...@@ -627,7 +626,6 @@ class SlicedAttnAddedKVProcessor:
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
......
...@@ -77,10 +77,10 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin): ...@@ -77,10 +77,10 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)
text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
return text_encoder_hidden_states, additive_clip_time_embeddings return text_encoder_hidden_states, additive_clip_time_embeddings
...@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa ...@@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
"decoder_num_inference_steps", "decoder_num_inference_steps",
"super_res_num_inference_steps", "super_res_num_inference_steps",
] ]
test_xformers_attention = False
@property @property
def text_embedder_hidden_size(self): def text_embedder_hidden_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment