Unverified Commit 6a89a6c9 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Update custom diffusion attn processor (#5663)

update custom diffusion attn processor
parent 9bafef34
...@@ -1361,6 +1361,7 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1361,6 +1361,7 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
# dropout # dropout
hidden_states = attn.to_out[1](hidden_states) hidden_states = attn.to_out[1](hidden_states)
return hidden_states return hidden_states
...@@ -1433,8 +1434,11 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): ...@@ -1433,8 +1434,11 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv: if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states) key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else: else:
key = attn.to_k(encoder_hidden_states) key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment