Unverified Commit 8fbf418d authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

feat: support kohya lora and loras with alphas (#459)

* kohya supported

* add a test for the LoRA

* add a gc
parent 46f4251a
import argparse
import logging
import os
import warnings
import torch
from diffusers.loaders import FluxLoraLoaderMixin
......@@ -9,6 +9,52 @@ from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Rename kohya-ss FLUX LoRA keys to diffusers-style keys.

    Kohya checkpoints (e.g. https://civitai.com/models/1118358?modelVersionId=1256866)
    prefix every key with ``lora_transformer_`` and use ``_``-separated module
    paths with ``lora_down``/``lora_up``, while diffusers expects
    ``.``-separated paths with ``lora_A``/``lora_B``.

    Args:
        state_dict: LoRA weights keyed by parameter name.

    Returns:
        A new dict with diffusers-style keys when the input is in kohya format
        (every key starts with ``lora_transformer_``); otherwise the input
        ``state_dict`` is returned unchanged.
    """
    # Only treat the checkpoint as kohya format when *every* key carries the
    # kohya prefix; a single non-matching key means "not kohya" — pass through.
    if not all(k.startswith("lora_transformer_") for k in state_dict):
        return state_dict

    # Ordered substring rewrites. Order matters: the prefix and block-name
    # rules must fire before the rules that assume "."-separated context.
    replacements = [
        ("lora_transformer_", "transformer."),
        ("norm_out_", "norm_out."),
        ("time_text_embed_", "time_text_embed."),
        ("guidance_embedder_", "guidance_embedder."),
        ("text_embedder_", "text_embedder."),
        ("timestep_embedder_", "timestep_embedder."),
        ("single_transformer_blocks_", "single_transformer_blocks."),
        ("_attn_", ".attn."),
        ("_norm_linear.", ".norm.linear."),
        ("_proj_mlp.", ".proj_mlp."),
        ("_proj_out.", ".proj_out."),
        ("transformer_blocks_", "transformer_blocks."),
        ("to_out_0.", "to_out.0."),
        ("_ff_context_net_0_proj.", ".ff_context.net.0.proj."),
        ("_ff_context_net_2.", ".ff_context.net.2."),
        ("_ff_net_0_proj.", ".ff.net.0.proj."),
        ("_ff_net_2.", ".ff.net.2."),
        ("_norm1_context_linear.", ".norm1_context.linear."),
        ("_norm1_linear.", ".norm1.linear."),
        (".lora_down.", ".lora_A."),
        (".lora_up.", ".lora_B."),
    ]
    new_state_dict = {}
    for k, v in state_dict.items():
        for old, new in replacements:
            k = k.replace(old, new)
        new_state_dict[k] = v
    return new_state_dict
def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
......@@ -16,6 +62,8 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
else:
tensors = {k: v for k, v in input_lora.items()}
tensors = handle_kohya_lora(tensors)
### convert the FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
......@@ -25,7 +73,14 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
if alphas is not None and len(alphas) > 0:
warnings.warn("Alpha values are not used in the conversion to diffusers format.")
for k, v in alphas.items():
key_A = k.replace(".alpha", ".lora_A.weight")
key_B = k.replace(".alpha", ".lora_B.weight")
assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
rank = new_tensors[key_A].shape[0]
assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
new_tensors[key_A] = new_tensors[key_A] * v / rank
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
......
......@@ -12,8 +12,14 @@ from .diffusers_converter import to_diffusers
from .packer import NunchakuWeightPacker
from .utils import is_nunchaku_format, pad
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# region utilities
......
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
......@@ -54,3 +57,38 @@ def test_flux_dev_turbo8_ghibsky_1024x1024():
cache_threshold=0,
expected_lpips=0.310 if get_precision() == "int4" else 0.168,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_kohya_lora():
    """End-to-end smoke test: load a kohya-format LoRA into a quantized FLUX.1-dev
    transformer and render a single image.

    Requires a CUDA GPU and network access to download the model weights; the
    skipif marker matches the other GPU tests in this file (Turing is unsupported).
    """
    precision = get_precision()  # auto-detect whether the precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    # Exercise the kohya-key conversion path: this checkpoint uses kohya naming.
    transformer.update_lora_params("mit-han-lab/nunchaku-test-models/hand_drawn_game.safetensors")
    transformer.set_lora_strength(1)
    prompt = (
        "masterful impressionism oil painting titled 'the violinist', the composition follows the rule of thirds, "
        "placing the violinist centrally in the frame. the subject is a young woman with fair skin and light blonde "
        "hair is styled in a long, flowing hairstyle with natural waves. she is dressed in an opulent, "
        "luxurious silver silk gown with a high waist and intricate gold detailing along the bodice. "
        "the gown's texture is smooth and reflective. she holds a violin under her chin, "
        "her right hand poised to play, and her left hand supporting the neck of the instrument. "
        "she wears a delicate gold necklace with small, sparkling gemstones that catch the light. "
        "her beautiful eyes focused on the viewer. the background features an elegantly furnished room "
        "with classical late 19th century decor. to the left, there is a large, ornate portrait of "
        "a man in a dark suit, set in a gilded frame. below this, a wooden desk with a closed book. "
        "to the right, a red upholstered chair with a wooden frame is partially visible. "
        "the room is bathed in natural light streaming through a window with red curtains, "
        "creating a warm, inviting atmosphere. the lighting highlights the violinist, "
        "casting soft shadows that enhance the depth and realism of the scene, highly aesthetic, "
        "harmonious colors, impressioniststrokes, "
        "<lora:style-impressionist_strokes-flux-by_daalis:1.0> <lora:image_upgrade-flux-by_zeronwo7829:1.0>"
    )
    # Smoke check only: the run must complete; the saved image is for manual inspection.
    image = pipeline(prompt, num_inference_steps=20, guidance_scale=3.5).images[0]
    image.save(f"flux.1-dev-{precision}-1.png")
import gc
from types import MethodType
import numpy as np
......@@ -15,6 +16,8 @@ from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
gc.collect()
torch.cuda.empty_cache()
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment