Unverified Commit 58c416ab authored by Jorge C. Gomes's avatar Jorge C. Gomes Committed by GitHub
Browse files

Fixes LoRAXFormersCrossAttnProcessor (#2207)

Related to #2124 
The current implementation is throwing a shape mismatch error. Which makes sense, as this line is obviously missing, comparing to XFormersCrossAttnProcessor and LoRACrossAttnProcessor.

I don't have formal tests, but I compared `LoRACrossAttnProcessor` and `LoRAXFormersCrossAttnProcessor` ad-hoc, and they produce the same results with this fix.
parent d46d78c5
......@@ -431,6 +431,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment