Unverified Commit de16f646 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

feat: when using PT 2.0 use LoRAAttnProcessor2_0 for text enc LoRA. (#3691)

parent 017ee160
...@@ -1168,7 +1168,10 @@ class LoraLoaderMixin: ...@@ -1168,7 +1168,10 @@ class LoraLoaderMixin:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0] hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRAAttnProcessor( attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
rank=rank, rank=rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment