Unverified Commit 5786b0e2 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

handle dtype xformers attention (#1196)

handle dtype xformers
parent 32b0736d
......@@ -492,6 +492,8 @@ class CrossAttention(nn.Module):
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment