Unverified Commit b2cfc7a0 authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

Fix slow tests (#689)

* revert using baddbmm in attention
- to fix `test_stable_diffusion_memory_chunking` test

* styling
parent 552b9670
...@@ -274,13 +274,8 @@ class CrossAttention(nn.Module): ...@@ -274,13 +274,8 @@ class CrossAttention(nn.Module):
return self.to_out(hidden_states) return self.to_out(hidden_states)
def _attention(self, query, key, value): def _attention(self, query, key, value):
attention_scores = torch.baddbmm( # TODO: use baddbmm for better performance
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1) attention_probs = attention_scores.softmax(dim=-1)
# compute attention output # compute attention output
hidden_states = torch.matmul(attention_probs, value) hidden_states = torch.matmul(attention_probs, value)
......
Markdown is supported
Attach a file by dragging & dropping or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment