Unverified commit de71fa59, authored by AnyISalIn, committed by GitHub

Fix error in PEFT LoRA when xformers is enabled (#5506)


Signed-off-by: AnyISalIn <anyisalin@gmail.com>
parent dcbfe662
```diff
@@ -909,6 +909,8 @@ class XFormersAttnProcessor:
     ):
         residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -936,15 +938,15 @@ class XFormersAttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -957,7 +959,7 @@ class XFormersAttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)

         # dropout
         hidden_states = attn.to_out[1](hidden_states)
```
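The change addresses a kwarg mismatch: when `USE_PEFT_BACKEND` is set, the attention projections (`to_q`, `to_k`, `to_v`, `to_out[0]`) are plain `torch.nn.Linear` layers whose scaling is handled by PEFT, so calling them with `scale=scale` raises a `TypeError`. The fix builds an `args` tuple that is empty under the PEFT backend and contains `scale` otherwise, so a single call site works for both layer types. Below is a minimal sketch of that dispatch pattern, not the diffusers source; the `LoRAScaledLinear` stand-in class and its `scale` parameter are hypothetical, invented here to mimic a LoRA-compatible projection.

```python
# Minimal sketch of the `*args` dispatch pattern, assuming a hypothetical
# LoRA-aware linear layer that accepts an extra positional `scale` argument.
import torch
import torch.nn as nn

USE_PEFT_BACKEND = True  # when True, LoRA scaling is handled by PEFT, not the layer


class LoRAScaledLinear(nn.Linear):
    """Hypothetical stand-in for a LoRA-compatible projection taking `scale`."""

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # A real implementation would scale the LoRA delta; the base path is enough here.
        return super().forward(x)


# Plain nn.Linear under the PEFT backend, LoRA-aware layer otherwise.
to_q = nn.Linear(8, 8) if USE_PEFT_BACKEND else LoRAScaledLinear(8, 8)

scale = 1.0
# The fix from the diff: only forward `scale` to layers that can accept it.
args = () if USE_PEFT_BACKEND else (scale,)

x = torch.randn(2, 8)
query = to_q(x, *args)  # works for both backends; to_q(x, scale=scale)
                        # would raise TypeError on plain nn.Linear
```

Packing the optional argument into a tuple keeps one call path for every projection instead of branching on the backend at each of the four call sites.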