Unverified Commit dd3cae33 authored by Asad Memon, committed by GitHub

Pass LoRA rank to LoRALinearLayer (#2191)

parent f73d0b6b
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
...
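Why the change matters: both processors already accepted a rank argument, but without forwarding it each LoRALinearLayer fell back to its own default instead of the caller-supplied value. The sketch below illustrates the idea; it is a minimal, assumed shape of a LoRA linear layer (a rank-r down/up factorization), not the exact diffusers implementation, and the names, initialization, and printed shapes are only for illustration.

import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    """Minimal sketch of a LoRA linear layer: the update is factored as
    up(down(x)), and `rank` sets the width of the bottleneck."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        # Down-projection into the rank-dimensional bottleneck.
        self.down = nn.Linear(in_features, rank, bias=False)
        # Up-projection back to the output width; zero-initialized so the
        # LoRA branch starts as a no-op.
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# With the fix, a non-default rank actually reaches the layer; before it,
# the third argument was simply never passed.
layer = LoRALinearLayer(320, 320, rank=8)
print(layer.down.weight.shape)  # torch.Size([8, 320]) -> bottleneck is 8, not the default 4
x = torch.randn(2, 77, 320)
print(layer(x).shape)           # torch.Size([2, 77, 320])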