"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "67e2f95cc4ff8c25f4d04f8bab46df02216527b2"
Unverified Commit 6110d7c9 authored by takuoko, committed by GitHub

[Bugfix] fix error of peft lora when xformers enabled (#5697)



* bugfix peft lora

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 65ef7a0c
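
The fix is one pattern, repeated at every projection call site: build the optional positional argument once, depending on whether the PEFT backend is active, and splat it into the call. When USE_PEFT_BACKEND is true, attn.to_q and the other projections are plain torch.nn.Linear (or PEFT-managed) modules whose forward takes no scale argument, so the old scale=scale keyword raised a TypeError; when it is false they are LoRA-aware linear layers that do accept it. A minimal standalone sketch of the idea (the stub layer below is an illustration, not the diffusers class):

import torch
import torch.nn as nn

USE_PEFT_BACKEND = True  # in diffusers this flag says whether peft manages the LoRA layers


class LoRACompatibleLinearStub(nn.Linear):
    # Stand-in for a projection whose forward accepts an extra LoRA scale.
    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # A real LoRA-aware layer would add scale * lora_delta(x) here.
        return super().forward(x)


to_q = nn.Linear(8, 8) if USE_PEFT_BACKEND else LoRACompatibleLinearStub(8, 8)

scale = 0.7
args = () if USE_PEFT_BACKEND else (scale,)  # the line this commit adds to each processor

hidden_states = torch.randn(2, 4, 8)
query = to_q(hidden_states, *args)  # works for both layer types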
@@ -879,6 +879,9 @@ class AttnAddedKVProcessor:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -891,17 +894,17 @@ class AttnAddedKVProcessor:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
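
Before this change, AttnAddedKVProcessor always passed scale=scale, so combining a PEFT-managed LoRA with this code path failed on the very first projection with the kind of error the commit title refers to. A short sketch of that failure mode, assuming the PEFT backend has swapped the projections for plain nn.Linear layers:

import torch
import torch.nn as nn

to_q = nn.Linear(8, 8)              # what attn.to_q looks like under the PEFT backend
hidden_states = torch.randn(2, 4, 8)

try:
    to_q(hidden_states, scale=1.0)  # the old call style in this processor
except TypeError as err:
    print(err)                      # e.g. "forward() got an unexpected keyword argument 'scale'"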
@@ -946,6 +949,9 @@ class AttnAddedKVProcessor2_0:
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -958,7 +964,7 @@ class AttnAddedKVProcessor2_0:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query, out_dim=4)

         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
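
The _2_0 variant patched above computes attention with torch.nn.functional.scaled_dot_product_attention from PyTorch 2.0, which is why heads are kept as a fourth dimension (out_dim=4) and the key/value concatenation happens on dim=2; the scale/args handling is otherwise identical to the non-2_0 processor. A rough sketch of that call shape (sizes are illustrative only):

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# The processor invokes this with dropout_p=0.0 and is_causal=False.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])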
@@ -1177,6 +1183,8 @@ class AttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1207,12 +1215,8 @@ class AttnProcessor2_0:
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
......
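
To exercise the path this commit fixes, load a LoRA through the PEFT backend and switch on xformers attention. A hypothetical usage sketch (the model id is a public Stable Diffusion checkpoint, the LoRA path is a placeholder, and peft plus xformers are assumed to be installed):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")              # routed through the PEFT backend when peft is installed
pipe.enable_xformers_memory_efficient_attention()   # the combination that previously raised a TypeError

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]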