Commit 742a8006 authored by Muyang Li, committed by Zhekai Zhang

[major] Fix the tempfile bug in the ComfyUI plugin

parent 27232e7b
+import logging
 import os
 import tempfile

 import folder_paths
 from safetensors.torch import save_file

 from nunchaku.lora.flux import comfyui2diffusers, convert_to_nunchaku_flux_lowrank_dict, detect_format, xlab2diffusers

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("SVDQuantFluxLoraLoader")


 class SVDQuantFluxLoraLoader:
     def __init__(self):
@@ -13,31 +16,20 @@ class SVDQuantFluxLoraLoader:
     @classmethod
     def INPUT_TYPES(s):
-        lora_name_list = [
-            "None",
-            *folder_paths.get_filename_list("loras"),
-            "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
-        ]
+        lora_name_list = ["None", *folder_paths.get_filename_list("loras")]

-        base_model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
-        prefix = os.path.join(folder_paths.models_dir, "diffusion_models")
-        local_base_model_folders = os.listdir(prefix)
-        local_base_model_folders = sorted(
-            [
-                folder
-                for folder in local_base_model_folders
-                if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
-            ]
-        )
-        base_model_paths = local_base_model_folders + base_model_paths
+        prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
+        base_model_paths = set()
+        for prefix in prefixes:
+            if os.path.exists(prefix) and os.path.isdir(prefix):
+                base_model_paths_ = os.listdir(prefix)
+                base_model_paths_ = [
+                    folder
+                    for folder in base_model_paths_
+                    if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
+                ]
+                base_model_paths.update(base_model_paths_)
+        base_model_paths = sorted(list(base_model_paths))

         return {
             "required": {
@@ -63,6 +55,12 @@ class SVDQuantFluxLoraLoader:
"tooltip": "How strongly to modify the diffusion model. This value can be negative.",
},
),
"save_converted_lora": (
["disable", "enable"],
{
"tooltip": "If enabled, the converted LoRA will be saved as a .safetensors file in the save directory of your LoRA file."
},
),
}
}
@@ -78,7 +76,15 @@ class SVDQuantFluxLoraLoader:
"Currently, only one LoRA nodes can be applied."
)
def load_lora(self, model, lora_name: str, lora_format: str, base_model_name: str, lora_strength: float):
def load_lora(
self,
model,
lora_name: str,
lora_format: str,
base_model_name: str,
lora_strength: float,
save_converted_lora: str,
):
if self.cur_lora_name == lora_name:
if self.cur_lora_name == "None":
pass # Do nothing since the lora is None
@@ -110,9 +116,22 @@ class SVDQuantFluxLoraLoader:
             base_model_path = os.path.join(base_model_name, "transformer_blocks.safetensors")
             state_dict = convert_to_nunchaku_flux_lowrank_dict(base_model_path, input_lora)
-            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp_file:
-                save_file(state_dict, tmp_file.name)
-                model.model.diffusion_model.model.update_lora_params(tmp_file.name)
+            if save_converted_lora == "enable" and lora_format != "svdquant":
+                dirname = os.path.dirname(lora_path)
+                basename = os.path.basename(lora_path)
+                if "int4" in base_model_path:
+                    precision = "int4"
+                else:
+                    assert "fp4" in base_model_path
+                    precision = "fp4"
+                converted_name = f"svdq-{precision}-{basename}"
+                lora_converted_path = os.path.join(dirname, converted_name)
+                if not os.path.exists(lora_converted_path):
+                    save_file(state_dict, lora_converted_path)
+                    logger.info(f"Saved converted LoRA to: {lora_converted_path}")
+                else:
+                    logger.info(f"Converted LoRA already exists at: {lora_converted_path}")
+            model.model.diffusion_model.model.update_lora_params(state_dict)
         else:
             model.model.diffusion_model.model.update_lora_params(lora_path)
         model.model.diffusion_model.model.set_lora_strength(lora_strength)
......
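This hunk is the tempfile bug named in the commit title. The old code wrote the converted state dict into a `NamedTemporaryFile` created with `delete=True` and then had the model reopen it by name; on Windows such a file is held with an exclusive handle while the context manager is open, so the second open fails, and on every platform the round trip through disk was unnecessary. A minimal reproduction of the failure mode (`load_by_name` is a hypothetical stand-in for any consumer that reopens the path, such as a safetensors loader):

import tempfile

def load_by_name(path: str) -> bytes:
    # Hypothetical consumer that reopens the file by its path.
    with open(path, "rb") as f:
        return f.read()

with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=True) as tmp:
    tmp.write(b"payload")
    tmp.flush()
    # Fine on POSIX; raises PermissionError on Windows because the
    # still-open NamedTemporaryFile holds an exclusive handle.
    data = load_by_name(tmp.name)

The fix avoids the file system entirely: `update_lora_params` now accepts the in-memory state dict directly (see the library change at the end of this diff), and writing a converted `.safetensors` file becomes an optional, user-visible feature instead.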
 import os

 import comfy.model_patcher
 import folder_paths
 import torch
@@ -7,8 +8,10 @@ from comfy.supported_models import Flux, FluxSchnell
 from diffusers import FluxTransformer2DModel
 from einops import rearrange, repeat
 from torch import nn

 from nunchaku import NunchakuFluxTransformer2dModel


 class ComfyUIFluxForwardWrapper(nn.Module):
     def __init__(self, model: NunchakuFluxTransformer2dModel, config):
         super(ComfyUIFluxForwardWrapper, self).__init__()
@@ -59,18 +62,10 @@ class ComfyUIFluxForwardWrapper(nn.Module):
         out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
         return out


 class SVDQuantFluxDiTLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = [
-            "mit-han-lab/svdq-int4-flux.1-schnell",
-            "mit-han-lab/svdq-int4-flux.1-dev",
-            "mit-han-lab/svdq-fp4-flux.1-schnell",
-            "mit-han-lab/svdq-fp4-flux.1-dev",
-            "mit-han-lab/svdq-int4-flux.1-canny-dev",
-            "mit-han-lab/svdq-int4-flux.1-depth-dev",
-            "mit-han-lab/svdq-int4-flux.1-fill-dev",
-        ]
         prefixes = folder_paths.folder_names_and_paths["diffusion_models"][0]
         local_folders = set()
         for prefix in prefixes:
@@ -82,8 +77,7 @@ class SVDQuantFluxDiTLoader:
                     if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
                 ]
                 local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths = local_folders + model_paths
+        model_paths = sorted(list(local_folders))
         ngpus = torch.cuda.device_count()
         return {
             "required": {
@@ -126,35 +120,37 @@ class SVDQuantFluxDiTLoader:
                 model_path = os.path.join(prefix, model_path)
                 break

-        # 验证 device_id 是否有效
+        # Check if the device_id is valid
         if device_id >= torch.cuda.device_count():
             raise ValueError(f"Invalid device_id: {device_id}. Only {torch.cuda.device_count()} GPUs available.")

-        # 获取 ComfyUI 指定 CUDA 设备的显存信息
+        # Get the GPU properties
         gpu_properties = torch.cuda.get_device_properties(device_id)
-        gpu_memory = gpu_properties.total_memory / (1024 ** 2)  # 转换为 MB
+        gpu_memory = gpu_properties.total_memory / (1024**2)  # Convert to MB
         gpu_name = gpu_properties.name
-        print(f"GPU {device_id} ({gpu_name}) 显存: {gpu_memory} MB")
+        print(f"GPU {device_id} ({gpu_name}) Memory: {gpu_memory} MB")

-        # 确定 CPU offload 是否启用
+        # Check if CPU offload needs to be enabled
         if cpu_offload == "auto":
-            if gpu_memory < 14336:  # 14GB 阈值
+            if gpu_memory < 14336:  # 14GiB threshold
                 cpu_offload_enabled = True
-                print("因显存小于14GB,启用 CPU offload")
+                print("VRAM < 14GiB, enable CPU offload")
             else:
                 cpu_offload_enabled = False
-                print("显存大于14GB,不启用 CPU offload")
+                print("VRAM >= 14GiB, disable CPU offload")
         elif cpu_offload == "enable":
             cpu_offload_enabled = True
-            print("用户启用 CPU offload")
+            print("Enable CPU offload")
         else:
             cpu_offload_enabled = False
-            print("用户禁用 CPU offload")
-
-        # 清理 GPU 缓存
-        # torch.cuda.empty_cache()
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=cpu_offload_enabled)
+            print("Disable CPU offload")
+
+        capability = torch.cuda.get_device_capability(0)
+        sm = f"{capability[0]}{capability[1]}"
+        precision = "fp4" if sm == "120" else "int4"
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            model_path, precision=precision, offload=cpu_offload_enabled
+        )
         transformer = transformer.to(device)

         dit_config = {
             "image_model": "flux",
......
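Two behavioral changes land in this hunk: the `auto` offload mode keeps its 14336 MB (14 GiB) VRAM threshold, and the quantization flavor is now chosen from the GPU's CUDA compute capability, with SM 120 (Blackwell, e.g. the RTX 50 series) taking the FP4 kernels and everything else falling back to INT4. A standalone sketch of the same probe; note that the committed code queries device 0 rather than the selected device_id, so the sketch parameterizes the device instead:

import torch

def pick_precision(device_id: int = 0) -> str:
    """Return "fp4" on SM 120 (Blackwell) GPUs and "int4" otherwise."""
    major, minor = torch.cuda.get_device_capability(device_id)
    # FP4 is assumed only for compute capability 12.0; older architectures
    # (Ampere SM 80, Ada SM 89, ...) use the INT4 path.
    return "fp4" if f"{major}{minor}" == "120" else "int4"

if torch.cuda.is_available():
    print(pick_precision())  # e.g. "int4" on an RTX 4090 (SM 89)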
@@ -45,7 +45,6 @@ class WrappedEmbedding(nn.Module):
 class SVDQuantTextEncoderLoader:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["mit-han-lab/svdq-flux.1-t5"]
         prefixes = folder_paths.folder_names_and_paths["text_encoders"][0]
         local_folders = set()
         for prefix in prefixes:
@@ -57,8 +56,7 @@ class SVDQuantTextEncoderLoader:
                     if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
                 ]
                 local_folders.update(local_folders_)
-        local_folders = sorted(list(local_folders))
-        model_paths.extend(local_folders)
+        model_paths = sorted(list(local_folders))
         return {
             "required": {
                 "model_type": (["flux"],),
@@ -68,8 +66,8 @@ class SVDQuantTextEncoderLoader:
"INT",
{"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
),
"t5_precision": (["BF16", "INT4"],),
"int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
"use_4bit_t5": (["disable", "enable"],),
"int4_model": (model_paths, {"tooltip": "The name of the 4-bit T5 model."}),
}
}
@@ -86,7 +84,7 @@ class SVDQuantTextEncoderLoader:
         text_encoder1: str,
         text_encoder2: str,
         t5_min_length: int,
-        t5_precision: str,
+        use_4bit_t5: str,
         int4_model: str,
     ):
         text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
@@ -105,7 +103,7 @@ class SVDQuantTextEncoderLoader:
         if model_type == "flux":
             clip.tokenizer.t5xxl.min_length = t5_min_length

-        if t5_precision == "INT4":
+        if use_4bit_t5 == "enable":
             transformer = clip.cond_stage_model.t5xxl.transformer
             param = next(transformer.parameters())
             dtype = param.dtype
......
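Renaming `t5_precision` (`BF16`/`INT4`) to `use_4bit_t5` (`disable`/`enable`) turns the input into a plain toggle. When enabled, the node reads the dtype and device of the currently loaded T5-XXL transformer before swapping in the 4-bit replacement, so the substitute can be placed identically. A schematic of that swap; `build_4bit_t5` is a hypothetical factory, and the attribute path is the one used by the node above:

def swap_in_4bit_t5(clip, build_4bit_t5):
    # Capture how the full-precision T5 lives before replacing it.
    transformer = clip.cond_stage_model.t5xxl.transformer
    param = next(transformer.parameters())
    dtype, device = param.dtype, param.device
    # The hypothetical factory returns a 4-bit module matching dtype/device.
    clip.cond_stage_model.t5xxl.transformer = build_4bit_t5(dtype=dtype, device=device)
    return clip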
@@ -3,13 +3,12 @@ import os
 import folder_paths
 import numpy as np
 import torch
-from image_gen_aux import DepthPreprocessor


 class FluxDepthPreprocessor:
     @classmethod
     def INPUT_TYPES(s):
-        model_paths = ["LiheYoung/depth-anything-large-hf"]
+        model_paths = []
         prefix = os.path.join(folder_paths.models_dir, "checkpoints")
         local_folders = os.listdir(prefix)
         local_folders = sorted(
@@ -36,9 +35,13 @@ class FluxDepthPreprocessor:
     TITLE = "FLUX.1 Depth Preprocessor"

     def depth_preprocess(self, image, model_path):
-        prefix = os.path.join(folder_paths.models_dir, "checkpoints")
-        if os.path.exists(os.path.join(prefix, model_path)):
-            model_path = os.path.join(prefix, model_path)
+        prefixes = folder_paths.folder_names_and_paths["checkpoints"][0]
+        for prefix in prefixes:
+            if os.path.exists(os.path.join(prefix, model_path)):
+                model_path = os.path.join(prefix, model_path)
+                break
+
+        from image_gen_aux import DepthPreprocessor
+
         processor = DepthPreprocessor.from_pretrained(model_path)
         np_image = np.asarray(image)
         np_result = np.array(processor(np_image)[0].convert("RGB"))
......
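Moving `from image_gen_aux import DepthPreprocessor` from module scope into `depth_preprocess` makes the dependency lazy: ComfyUI can import and register the node even when the optional `image_gen_aux` package is missing, and the failure only surfaces if the node actually runs. A generic sketch of the pattern, with a hypothetical `optional_heavy_lib`:

def run_node(data):
    # Deferred import: module load (and node registration) succeeds even if
    # the optional dependency is absent; failure is delayed until first use.
    try:
        from optional_heavy_lib import Processor  # hypothetical optional dependency
    except ImportError as e:
        raise RuntimeError("This node requires optional_heavy_lib; please install it.") from e
    return Processor().process(data)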
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", text_encoder_2=text_encoder_2, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
)  # set offload to False if you want to disable offloading
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.enable_sequential_cpu_offload()  # remove this line if you want to disable the CPU offloading
image = pipeline(
    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell.png")
@@ -60,12 +60,12 @@ def quantize_t5_encoder(
         if isinstance(module, nn.Linear):
             if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix):
                 print(f"Switching {name} to W4Linear")
-                qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
-                qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
-                if qmodule.bias is not None:
-                    qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
-                qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
-                qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
+                qmodule = W4Linear.from_linear(module, group_size=128, init_only=False)
+                # qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
+                # if qmodule.bias is not None:
+                #     qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
+                # qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
+                # qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])

                 # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                 qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
......
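The switch from `init_only=True` to `init_only=False` changes where the 4-bit tensors come from: with `init_only=True`, `from_linear` only allocated empty buffers and the now-commented lines copied pre-quantized tensors (`qweight`, `scales`, `scaled_zeros`) out of the checkpoint, whereas with `init_only=False` the layer quantizes the module's own floating-point weights at construction time, making those copies redundant. A toy illustration of the two paths, with a hypothetical `ToyW4Linear` standing in for `W4Linear` (only the `from_linear(module, group_size, init_only)` shape of the API is taken from the code above; the quantization math is illustrative):

import torch
from torch import nn

class ToyW4Linear:
    def __init__(self, linear: nn.Linear, group_size: int):
        out_f, in_f = linear.weight.shape
        # Buffers for the quantized weights and per-group scales.
        self.qweight = torch.zeros(out_f, in_f, dtype=torch.int8)
        self.scales = torch.zeros(out_f, in_f // group_size)

    @classmethod
    def from_linear(cls, linear: nn.Linear, group_size: int, init_only: bool):
        q = cls(linear, group_size)
        if not init_only:
            # init_only=False: quantize the fp weights right here.
            w = linear.weight.detach().reshape(linear.weight.shape[0], -1, group_size)
            scales = w.abs().amax(dim=-1, keepdim=True) / 7  # symmetric int4 range
            q.qweight = (w / scales).round().to(torch.int8).reshape_as(linear.weight)
            q.scales = scales.squeeze(-1)
        # init_only=True: buffers stay empty; the caller copies checkpoint tensors in.
        return q

q = ToyW4Linear.from_linear(nn.Linear(256, 128), group_size=128, init_only=False)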
@@ -8,7 +8,6 @@ from huggingface_hub import utils
 from packaging.version import Version
 from torch import nn

-from nunchaku.utils import fetch_or_download
 from .utils import NunchakuModelLoaderMixin, pad_tensor
 from ..._C import QuantizedFluxModel, utils as cutils
 from ...utils import load_state_dict_in_safetensors
@@ -224,13 +223,18 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
             new_state_dict[k] = v
         self.load_state_dict(new_state_dict, strict=True)

-    def update_lora_params(self, path: str):
-        state_dict = load_state_dict_in_safetensors(path)
+    def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
+        if isinstance(path_or_state_dict, dict):
+            state_dict = path_or_state_dict
+        else:
+            state_dict = load_state_dict_in_safetensors(path_or_state_dict)

         unquantized_loras = {}
         for k in state_dict.keys():
             if "transformer_blocks" not in k:
                 unquantized_loras[k] = state_dict[k]
         for k in unquantized_loras.keys():
             state_dict.pop(k)

         self.unquantized_loras = unquantized_loras
         if len(unquantized_loras) > 0:
@@ -239,10 +243,9 @@
             self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
             self.update_unquantized_lora_params(1)

-        path = fetch_or_download(path)
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
-        block.m.load(path, True)
+        block.m.loadDict(path_or_state_dict, True)

     def set_lora_strength(self, strength: float = 1):
         block = self.transformer_blocks[0]
......
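This is the library half of the fix: `update_lora_params` now accepts either a `.safetensors` path or an already-converted state dict, which is what lets the ComfyUI node above skip the temporary file (the `fetch_or_download` call, only meaningful for paths, is dropped along with its import). A usage sketch of the new dict form; the file names are placeholders, and the converter signature follows its use earlier in this diff:

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.lora.flux import convert_to_nunchaku_flux_lowrank_dict

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")

# Convert a Diffusers-format LoRA against the base model's quantized blocks,
# then pass the resulting dict directly, with no intermediate file.
state_dict = convert_to_nunchaku_flux_lowrank_dict(
    "mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",  # base blocks
    "my-lora.safetensors",  # placeholder LoRA path
)
transformer.update_lora_params(state_dict)
transformer.set_lora_strength(0.8)

Per the signature, a plain path can still be passed instead of a dict.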