Unverified Commit 7fcce6f3 authored by woctordho, committed by GitHub

feat: add paddings if LoRA ranks for `q, k, v` are different (#603)

* Add paddings if LoRA ranks for q, k, v are different

* No need to create a list
parent f0c83919
```diff
@@ -29,6 +29,7 @@ import argparse
 import os
 import torch
+import torch.nn.functional as F
 from safetensors.torch import save_file
 from .diffusers_converter import to_diffusers
@@ -120,8 +121,14 @@ def compose_lora(
             .replace(".add_q_proj.", ".add_v_proj.")
         ]
-        assert q_a.shape[0] == k_a.shape[0] == v_a.shape[0]
-        assert q_b.shape[1] == k_b.shape[1] == v_b.shape[1]
+        # Add paddings if their ranks are different
+        max_rank = max(q_a.shape[0], k_a.shape[0], v_a.shape[0])
+        q_a = F.pad(q_a, (0, 0, 0, max_rank - q_a.shape[0]))
+        k_a = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))
+        v_a = F.pad(v_a, (0, 0, 0, max_rank - v_a.shape[0]))
+        q_b = F.pad(q_b, (0, max_rank - q_b.shape[1]))
+        k_b = F.pad(k_b, (0, max_rank - k_b.shape[1]))
+        v_b = F.pad(v_b, (0, max_rank - v_b.shape[1]))
         if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
             lora_a = q_a
```
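For reference, here is a minimal sketch (with hypothetical shapes and tensor names, not taken from the repository) of why the zero-padding is harmless: padding a LoRA A matrix with extra rows at the bottom and the matching B matrix with extra columns on the right leaves the low-rank product `B @ A` unchanged, so `q`, `k`, and `v` can be brought to a common rank before composition.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: LoRA A is (rank, in_features), LoRA B is (out_features, rank).
in_features, out_features = 64, 64
q_rank, k_rank = 16, 8  # q and k trained with different ranks

k_a = torch.randn(k_rank, in_features)
k_b = torch.randn(out_features, k_rank)

max_rank = max(q_rank, k_rank)

# F.pad's tuple is ordered (last dim left, last dim right,
# second-to-last dim left, second-to-last dim right, ...):
# pad A along dim 0 (rows appended at the bottom),
# pad B along dim 1 (columns appended on the right).
k_a_padded = F.pad(k_a, (0, 0, 0, max_rank - k_a.shape[0]))  # -> (max_rank, in_features)
k_b_padded = F.pad(k_b, (0, max_rank - k_b.shape[1]))        # -> (out_features, max_rank)

# The extra rows/columns are zeros, so the low-rank update B @ A is unchanged.
assert torch.allclose(k_b @ k_a, k_b_padded @ k_a_padded)
```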