"vscode:/vscode.git/clone" did not exist on "ad2450b129de39a256cb15f14708b10bcb5466b5"
Unverified commit b2cc2361, authored by Muyang Li, committed by GitHub

fix: convert the unet state_dict to PeFT one to fix the removal LoRA (#402)

* start debugging

* fix: convert the unet state_dict to PeFT one to fix the removal LoRA

* style: make linter happy

* fix the lpips
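Reproduction script 1: the bf16 reference run of the object-removal LoRA using the stock diffusers FluxFillPipeline.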
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

image = load_image("./removal_image.png")
mask = load_image("./removal_mask.png")

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    "./loras/removalV2.safetensors"
)  # Path to your LoRA safetensors; can also be a remote HuggingFace path
pipe.fuse_lora(lora_scale=1)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="",
    image=image,
    mask_image=mask,
    height=720,
    width=1280,
    guidance_scale=30,
    num_inference_steps=20,
    max_sequence_length=512,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux.1-fill-dev-bf16.png")
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

image = load_image("./removal_image.png")
mask = load_image("./removal_mask.png")

precision = get_precision()  # auto-detects whether your GPU needs the 'int4' or 'fp4' model
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-fill-dev")

### LoRA Related Code ###
transformer.update_lora_params(
    "loras/removalV2.safetensors"
)  # Path to your LoRA safetensors; can also be a remote HuggingFace path
transformer.set_lora_strength(1)  # Your LoRA strength here
### End of LoRA Related Code ###

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="",
    image=image,
    mask_image=mask,
    height=720,
    width=1280,
    guidance_scale=30,
    num_inference_steps=20,
    max_sequence_length=512,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save(f"flux.1-fill-dev-{precision}.png")
from PIL import Image, ImageChops

# Open the RGBA image
img = Image.open("removal.png").convert("RGBA")

# Split into R, G, B, A channels and invert the alpha channel
r, g, b, a = img.split()
a_inverted = ImageChops.invert(a)

# Merge R, G, B back into an RGB image
rgb_img = Image.merge("RGB", (r, g, b))

# Save the RGB image and the inverted alpha mask separately
rgb_img.save("removal_image.png")
a_inverted.save("removal_mask.png")
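For debugging, the helper below builds a stripped-down variant of the LoRA that keeps only the weights of the single transformer blocks and zeroes out everything else.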
import torch
from safetensors.torch import save_file

from nunchaku.utils import load_state_dict_in_safetensors

if __name__ == "__main__":
    sd = load_state_dict_in_safetensors("loras/removalV2.safetensors")
    new_sd = {}
    for k, v in sd.items():
        # Keep the LoRA weights for the single transformer blocks; zero out everything else.
        if ".single_transformer_blocks." in k:
            new_sd[k] = v
        else:
            new_sd[k] = torch.zeros_like(v)
    save_file(new_sd, "loras/removalV2-single.safetensors")
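The fix itself lands in the LoRA-to-diffusers converter: the state dict returned by FluxLoraLoaderMixin.lora_state_dict is now additionally passed through diffusers' convert_unet_state_dict_to_peft, so its keys are normalized to the PEFT-style lora_A / lora_B naming before the rest of the conversion runs.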
@@ -4,6 +4,7 @@ import warnings
import torch
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
@@ -21,6 +22,7 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
tensors[k] = v.to(torch.bfloat16)
new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
if alphas is not None and len(alphas) > 0:
warnings.warn("Alpha values are not used in the conversion to diffusers format.")
......
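The LPIPS regression threshold for the fp4 branch of the waterfall test case is also raised (the "fix the lpips" item in the commit message):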
@@ -54,7 +54,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
"waterfall",
23,
0.6,
- 0.253 if get_precision() == "int4" else 0.226,
+ 0.253 if get_precision() == "int4" else 0.254,
),
],
)
......